1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/cpu/ir_emitter.h"
17
18 #include <stddef.h>
19 #include <stdint.h>
20 #include <algorithm>
21 #include <iterator>
22 #include <limits>
23 #include <memory>
24 #include <utility>
25 #include <vector>
26
27 // IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
28 #include "absl/container/flat_hash_map.h"
29 #include "absl/container/flat_hash_set.h"
30 #include "absl/strings/str_cat.h"
31 #include "absl/strings/str_format.h"
32 #include "absl/types/span.h"
33 #include "llvm/CodeGen/TargetRegisterInfo.h"
34 #include "llvm/CodeGen/TargetSubtargetInfo.h"
35 #include "llvm/IR/BasicBlock.h"
36 #include "llvm/IR/Constants.h"
37 #include "llvm/IR/GlobalVariable.h"
38 #include "llvm/IR/Instructions.h"
39 #include "llvm/IR/Intrinsics.h"
40 #include "llvm/IR/LLVMContext.h"
41 #include "tensorflow/compiler/xla/layout_util.h"
42 #include "tensorflow/compiler/xla/map_util.h"
43 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
44 #include "tensorflow/compiler/xla/service/cpu/cpu_options.h"
45 #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
46 #include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h"
47 #include "tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h"
48 #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
49 #include "tensorflow/compiler/xla/service/cpu/ir_function.h"
50 #include "tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h"
51 #include "tensorflow/compiler/xla/service/cpu/shape_partition.h"
52 #include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h"
53 #include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"
54 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
55 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
56 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
57 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
58 #include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h"
59 #include "tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h"
60 #include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
61 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
62 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
63 #include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h"
64 #include "tensorflow/compiler/xla/shape_util.h"
65 #include "tensorflow/compiler/xla/status_macros.h"
66 #include "tensorflow/compiler/xla/types.h"
67 #include "tensorflow/compiler/xla/util.h"
68 #include "tensorflow/compiler/xla/window_util.h"
69 #include "tensorflow/core/lib/core/bits.h"
70 #include "tensorflow/core/lib/core/errors.h"
71 #include "tensorflow/core/lib/math/math_util.h"
72 #include "tensorflow/core/platform/logging.h"
73
74 namespace xla {
75
76 namespace {
77 using llvm_ir::IrName;
78 using llvm_ir::SetToFirstInsertPoint;
79 } // namespace
80
81 namespace cpu {
82
IrEmitter::IrEmitter(
    const HloModule& hlo_module, const BufferAssignment& assignment,
    llvm::Module* llvm_module,
    std::unordered_map<const HloInstruction*, int64> instruction_to_profile_idx,
    std::unordered_map<const HloComputation*, int64> computation_to_profile_idx,
    const TargetMachineFeatures* target_machine_features,
    bool emit_code_for_msan)
90 : assignment_(assignment),
91 module_(llvm_module),
92 arch_type_(llvm::Triple(llvm_module->getTargetTriple()).getArch()),
93 b_(llvm_module->getContext()),
94 instruction_to_profile_idx_(std::move(instruction_to_profile_idx)),
95 computation_to_profile_idx_(std::move(computation_to_profile_idx)),
96 alias_analysis_(hlo_module, assignment, &llvm_module->getContext()),
97 hlo_module_config_(hlo_module.config()),
98 is_top_level_computation_(false),
99 target_machine_features_(*target_machine_features),
100 emit_code_for_msan_(emit_code_for_msan) {
101 b_.setFastMathFlags(llvm_ir::GetCpuFastMathFlags(hlo_module_config_));
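  // Partition the module's computations into those that are emitted as
  // thread-local functions (e.g. reducer and comparator computations whose
  // temporaries live on the stack) and those that operate on globally
  // allocated buffers.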
102 Status s = GatherComputationsByAllocationType(
103 &hlo_module, &thread_local_computations_, &global_computations_);
104 absl::c_sort(thread_local_computations_);
105 absl::c_sort(global_computations_);
106 TF_CHECK_OK(s) << "Should have failed buffer assignment.";
107 }
108
StatusOr<llvm::Function*> IrEmitter::EmitComputation(
    HloComputation* computation, const string& function_name_prefix,
    bool is_top_level_computation,
    absl::Span<HloInstruction* const> instruction_order) {
113 string function_name = name_uniquer_.GetUniqueName(function_name_prefix);
114 VLOG(2) << "Emitting IR for CPU function [" << function_name_prefix << "]";
115 is_top_level_computation_ = is_top_level_computation;
116 num_dynamic_loop_bounds_ = 0;
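  // If the root is partitioned across its outer dimensions, the emitted
  // function takes dynamic loop bounds so the runtime can parallelize the
  // outermost loops across those partitions.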
117 if (!computation->root_instruction()->outer_dimension_partitions().empty()) {
118 num_dynamic_loop_bounds_ =
119 computation->root_instruction()->outer_dimension_partitions().size();
120 }
121
122 if (computation->root_instruction()->opcode() != HloOpcode::kOutfeed) {
123 TF_ASSIGN_OR_RETURN(
124 computation_root_allocation_,
125 assignment_.GetUniqueTopLevelSlice(computation->root_instruction()));
126 }
127
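  // Record which buffer allocation backs each parameter so that buffer
  // pointer emission can recognize parameter buffers by allocation index.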
128 for (const HloInstruction* param : computation->parameter_instructions()) {
129 TF_ASSIGN_OR_RETURN(BufferAllocation::Slice param_slice,
130 assignment_.GetUniqueTopLevelSlice(param));
131 computation_parameter_allocations_[param_slice.allocation()->index()] =
132 param->parameter_number();
133 }
134
135 InitializeIrFunction(function_name);
  // The rdtscp instruction is x86-specific. We fall back to LLVM's generic
  // readcyclecounter intrinsic if it is unavailable.
138 bool use_rdtscp = arch_type_ == llvm::Triple::ArchType::x86 ||
139 arch_type_ == llvm::Triple::ArchType::x86_64;
140 profiling_state_ = ProfilingState(use_rdtscp);
141 TF_RETURN_IF_ERROR(computation->AcceptOrdered(this, instruction_order));
142 llvm::Function* ir_function = compute_function_->function();
143 InsertOrDie(&emitted_functions_, computation, ir_function);
144 // Delete 'compute_function', finalizing 'ir_function' and restoring caller
145 // IR insert point.
146 compute_function_.reset();
147 computation_root_allocation_ = BufferAllocation::Slice();
148 computation_parameter_allocations_.clear();
149 return ir_function;
150 }
151
void IrEmitter::InitializeIrFunction(const string& function_name) {
  // Functions with local linkage get an inlining bonus. Because we know
  // a priori that embedded (non-entry) functions will not have their names
  // resolved externally, give them local linkage.
156 llvm::Function::LinkageTypes linkage =
157 is_top_level_computation_ ? llvm::GlobalValue::ExternalLinkage
158 : llvm::GlobalValue::InternalLinkage;
159 // Create and initialize new IrFunction.
160 compute_function_.reset(new IrFunction(function_name, linkage,
161 hlo_module_config_, module_, &b_,
162 num_dynamic_loop_bounds_));
163 }
164
IrEmitter::~IrEmitter() {}
166
Status IrEmitter::HandleBitcast(HloInstruction* bitcast) {
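  // A bitcast reinterprets the operand's existing buffer without copying any
  // data; we reuse the operand's emitted value, cast to the bitcast's IR type.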
168 VLOG(2) << "HandleBitcast: " << bitcast->ToString();
169 emitted_value_[bitcast] =
170 BitCast(GetEmittedValueFor(bitcast->operand(0)),
171 IrShapeType(bitcast->shape())->getPointerTo(), IrName(bitcast));
172 return Status::OK();
173 }
174
llvm::Constant* IrEmitter::EmitGlobalForLiteral(const Literal& literal) {
176 llvm::Constant* initializer =
177 llvm_ir::ConvertLiteralToIrConstant(literal, module_);
178 llvm::GlobalVariable* result_global = new llvm::GlobalVariable(
179 /*Module=*/*module_,
180 /*Type=*/initializer->getType(),
181 /*isConstant=*/true,
182 /*Linkage=*/llvm::GlobalValue::PrivateLinkage,
183 /*Initializer=*/initializer,
184 /*Name=*/"");
185 result_global->setAlignment(MinimumAlignmentForShape(literal.shape()));
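  // The address of the global is not significant, which allows LLVM to merge
  // identical constant globals.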
186 result_global->setUnnamedAddr(llvm::GlobalVariable::UnnamedAddr::Global);
187 return llvm::ConstantExpr::getBitCast(
188 result_global, IrShapeType(literal.shape())->getPointerTo());
189 }
190
Status IrEmitter::EmitConstantGlobals() {
192 for (const BufferAllocation& allocation : assignment_.Allocations()) {
193 if (!allocation.is_constant()) {
194 continue;
195 }
196
197 const Literal& literal = llvm_ir::LiteralForConstantAllocation(allocation);
198 llvm::Constant* global_for_const;
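    // Reuse an already-emitted global when several constant allocations are
    // backed by the same literal.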
199 auto it = emitted_literals_.find(&literal);
200 if (it != emitted_literals_.end()) {
201 global_for_const = it->second;
202 } else {
203 global_for_const = EmitGlobalForLiteral(literal);
204 InsertOrDie(&emitted_literals_, &literal, global_for_const);
205 }
206
207 InsertOrDie(&constant_buffer_to_global_, allocation.index(),
208 global_for_const);
209 }
210
211 return Status::OK();
212 }
213
Status IrEmitter::HandleConstant(HloInstruction* constant) {
215 VLOG(2) << "HandleConstant: " << constant->ToString();
216 // IrEmitter::EmitConstantGlobals has already taken care of emitting the body
217 // of the constant.
218 return EmitTargetAddressForOp(constant);
219 }
220
Status IrEmitter::HandleCopy(HloInstruction* copy) {
222 if (copy->shape().IsTuple()) {
223 // kCopy shallow copies a tuple so just memcpy the top-level buffer.
224 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(copy));
225 return EmitMemcpy(*(copy->operand(0)), *copy);
226 } else if (copy->shape().IsArray()) {
227 // Use the elemental emitter for array shapes.
228 return DefaultAction(copy);
229 }
230 return Unimplemented("unsupported operand type %s for copy instruction",
231 PrimitiveType_Name(copy->shape().element_type()));
232 }
233
234 // Calculate the alignment of a buffer allocated for a given primitive type.
int IrEmitter::MinimumAlignmentForPrimitiveType(PrimitiveType primitive_type) {
236 int64 byte_size = ShapeUtil::ByteSizeOfPrimitiveType(primitive_type);
237 DCHECK_GE(byte_size, 0);
238 // Largest scalar is a complex128 so we don't need to worry about the
239 // int64->int truncation here.
240 DCHECK_LE(byte_size, 16);
241
242 // Allocations may be 8-byte aligned if part of a small block.
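  // For example, an F32 buffer gets 4-byte alignment, while F64, C64, and
  // C128 buffers are capped at 8-byte alignment.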
243 return std::min(8LL, byte_size);
244 }
245
int64 IrEmitter::ByteSizeOf(const Shape& shape) const {
247 return llvm_ir::ByteSizeOf(shape, module_->getDataLayout());
248 }
249
250 // Calculate the alignment of a buffer allocated for a given shape.
int IrEmitter::MinimumAlignmentForShape(const Shape& shape) {
252 if (ShapeUtil::IsScalar(shape)) {
253 return MinimumAlignmentForPrimitiveType(shape.element_type());
254 }
255
256 int64 buffer_size = ByteSizeOf(shape);
257 DCHECK_GE(buffer_size, 0);
258 DCHECK_LE(buffer_size, SIZE_MAX);
259
260 return target_machine_features_.minimum_alignment_for_allocation(buffer_size);
261 }
262
void IrEmitter::AttachAlignmentMetadataForLoad(llvm::LoadInst* load,
                                               const Shape& shape) {
265 int alignment = MinimumAlignmentForShape(shape);
266 if (alignment > 1) {
267 llvm_ir::SetAlignmentMetadataForLoad(load, alignment);
268 }
269 }
270
void IrEmitter::AttachAlignmentMetadataForLoad(llvm::LoadInst* load,
                                               int64 buffer_size) {
273 int alignment =
274 target_machine_features_.minimum_alignment_for_allocation(buffer_size);
275 if (alignment > 1) {
276 llvm_ir::SetAlignmentMetadataForLoad(load, alignment);
277 }
278 }
279
void IrEmitter::AttachDereferenceableMetadataForLoad(llvm::LoadInst* load,
                                                     const Shape& shape) {
282 AttachDereferenceableMetadataForLoad(load, ByteSizeOf(shape));
283 }
284
void IrEmitter::AttachDereferenceableMetadataForLoad(llvm::LoadInst* load,
                                                     int64 buffer_size) {
287 if (buffer_size > 0) {
288 llvm_ir::SetDereferenceableMetadataForLoad(load, buffer_size);
289 }
290 }
291
Status IrEmitter::HandleGetTupleElement(HloInstruction* get_tuple_element) {
293 // A tuple is an array of pointers, one for each operand. Each pointer points
294 // to the output buffer of its corresponding operand. A GetTupleElement
295 // instruction forwards a pointer to the tuple element buffer at the given
296 // index.
297 auto operand = get_tuple_element->operand(0);
298 const Shape& shape = get_tuple_element->shape();
299 emitted_value_[get_tuple_element] = llvm_ir::EmitGetTupleElement(
300 shape, get_tuple_element->tuple_index(), MinimumAlignmentForShape(shape),
301 GetEmittedValueFor(operand), &b_);
302 return Status::OK();
303 }
304
Status IrEmitter::HandleSelect(HloInstruction* select) {
306 auto pred = select->operand(0);
307 TF_RET_CHECK(pred->shape().element_type() == PRED);
308 return DefaultAction(select);
309 }
310
Status IrEmitter::HandleTupleSelect(HloInstruction* tuple_select) {
312 auto pred = tuple_select->operand(0);
313 auto on_true = tuple_select->operand(1);
314 auto on_false = tuple_select->operand(2);
315 TF_RET_CHECK(pred->shape().element_type() == PRED);
316 TF_RET_CHECK(ShapeUtil::IsScalar(pred->shape()));
317 TF_RET_CHECK(tuple_select->shape().IsTuple());
318 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(tuple_select));
319 llvm_ir::EmitTupleSelect(GetIrArrayFor(tuple_select), GetIrArrayFor(pred),
320 GetEmittedValueFor(on_true),
321 GetEmittedValueFor(on_false), &b_);
322 return Status::OK();
323 }
324
Status IrEmitter::HandleInfeed(HloInstruction* instruction) {
326 HloInfeedInstruction* infeed = Cast<HloInfeedInstruction>(instruction);
327 VLOG(2) << "HandleInfeed: " << infeed->ToString();
328
329 // The infeed operation produces a two-element tuple containing data and a
330 // token value. HloInfeedInstruction::infeed_shape gives us the data shape.
331 const Shape& data_shape = infeed->infeed_shape();
332 DCHECK(ShapeUtil::Equal(data_shape,
333 ShapeUtil::GetTupleElementShape(infeed->shape(), 0)));
334 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(infeed));
335
336 // Write the tuple index table.
337 TF_ASSIGN_OR_RETURN(BufferAllocation::Slice data_slice,
338 assignment_.GetUniqueSlice(infeed, {0}));
339 llvm::Value* data_address = EmitBufferPointer(data_slice, data_shape);
340 TF_ASSIGN_OR_RETURN(BufferAllocation::Slice token_slice,
341 assignment_.GetUniqueSlice(infeed, {1}));
342 llvm::Value* token_address = EmitBufferPointer(
343 token_slice, ShapeUtil::GetTupleElementShape(infeed->shape(), 1));
344 llvm_ir::EmitTuple(GetIrArrayFor(infeed), {data_address, token_address}, &b_);
345
346 if (data_shape.IsTuple()) {
347 TF_RET_CHECK(!ShapeUtil::IsNestedTuple(data_shape));
348
349 // For a tuple, we first copy each of the internal elements to
350 // their corresponding target locations. We then construct the
351 // tuple outer buffer containing pointers to the internal
352 // elements.
353 std::vector<llvm::Value*> tuple_element_addresses;
354 for (int64 i = 0; i < data_shape.tuple_shapes_size(); ++i) {
355 TF_ASSIGN_OR_RETURN(BufferAllocation::Slice buffer,
356 assignment_.GetUniqueSlice(infeed, {0, i}));
357
358 const Shape& tuple_element_shape =
359 ShapeUtil::GetTupleElementShape(data_shape, i);
360
361 // Only the outer tuple buffer's target address is obtained from
362 // GetEmittedValueFor, to handle the case when Infeed is the root
363 // instruction. Target addresses for internal elements can be obtained
364 // from EmitBufferPointer.
365 llvm::Value* tuple_element_address =
366 EmitBufferPointer(buffer, tuple_element_shape);
367
368 TF_RETURN_IF_ERROR(EmitXfeedTransfer(
369 XfeedKind::kInfeed, tuple_element_shape, tuple_element_address));
370
371 tuple_element_addresses.push_back(tuple_element_address);
372 }
373
374 llvm_ir::EmitTuple(llvm_ir::IrArray(data_address, data_shape),
375 tuple_element_addresses, &b_);
376 } else {
377 TF_RETURN_IF_ERROR(
378 EmitXfeedTransfer(XfeedKind::kInfeed, data_shape, data_address));
379 }
380
381 return Status::OK();
382 }
383
Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape,
                                    llvm::Value* program_buffer_address) {
386 int64 length = ByteSizeOf(shape);
387 if (length <= 0 || length > std::numeric_limits<int32>::max()) {
388 return InvalidArgument(
389 "xfeed (infeed or outfeed) buffer length %d is outside the valid "
390 "size range",
391 length);
392 }
393 int32 length_32 = static_cast<int32>(length);
394
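  // Encode the shape as a self-describing constant so the runtime can verify
  // that the transferred buffer has the shape we expect.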
395 int32 shape_length;
396 TF_ASSIGN_OR_RETURN(
397 llvm::Value * shape_ptr,
398 llvm_ir::EncodeSelfDescribingShapeConstant(shape, &shape_length, &b_));
399
400 llvm::Type* int32_type = b_.getInt32Ty();
401 llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext());
402 llvm::FunctionType* acquire_type = llvm::FunctionType::get(
403 i8_ptr_type,
404 {/*run_options*/ i8_ptr_type, /*buffer_length*/ int32_type,
405 /*shape_ptr*/ i8_ptr_type, /*shape_length*/ int32_type},
406 /*isVarArg=*/false);
407
408 llvm::Function* acquire_func;
409 if (kind == XfeedKind::kInfeed) {
410 acquire_func = llvm::dyn_cast<llvm::Function>(
411 module_
412 ->getOrInsertFunction(
413 runtime::kAcquireInfeedBufferForDequeueSymbolName, acquire_type)
414 .getCallee());
415 } else {
416 acquire_func = llvm::dyn_cast<llvm::Function>(
417 module_
418 ->getOrInsertFunction(
419 runtime::kAcquireOutfeedBufferForPopulationSymbolName,
420 acquire_type)
421 .getCallee());
422 }
423 acquire_func->setCallingConv(llvm::CallingConv::C);
424
425 llvm::FunctionType* release_type = llvm::FunctionType::get(
426 b_.getVoidTy(),
427 {/*run_options*/ i8_ptr_type, /*buffer_length*/ int32_type,
428 /*buffer_ptr*/ i8_ptr_type, /*shape_ptr*/ i8_ptr_type,
429 /*shape_length*/ int32_type},
430 /*isVarArg=*/false);
431
432 llvm::Function* release_func;
433 if (kind == XfeedKind::kInfeed) {
434 release_func = llvm::dyn_cast<llvm::Function>(
435 module_
436 ->getOrInsertFunction(
437 runtime::kReleaseInfeedBufferAfterDequeueSymbolName,
438 release_type)
439 .getCallee());
440 } else {
441 release_func = llvm::dyn_cast<llvm::Function>(
442 module_
443 ->getOrInsertFunction(
444 runtime::kReleaseOutfeedBufferAfterPopulationSymbolName,
445 release_type)
446 .getCallee());
447 }
448 release_func->setCallingConv(llvm::CallingConv::C);
449
  // Implementation note: this call informs the runtime that it wants a buffer
  // of size exactly 'length_32', and the runtime is responsible for
  // check-failing the process if there is a mismatch, rather than passing us
  // back a buffer that we might overrun.
454 llvm::Value* acquired_pointer = Call(
455 acquire_func, {GetExecutableRunOptionsArgument(), b_.getInt32(length_32),
456 shape_ptr, b_.getInt32(shape_length)});
457
458 if (kind == XfeedKind::kInfeed) {
459 // Copy to the program buffer address from the acquired buffer.
460 MemCpy(program_buffer_address, /*DstAlign=*/1, acquired_pointer,
461 /*SrcAlign=*/1, length_32);
462 } else {
463 // Outfeed -- copy from the in-program address to the acquired buffer.
464 MemCpy(acquired_pointer, /*DstAlign=*/1, program_buffer_address,
465 /*SrcAlign=*/1, length_32);
466 }
467
468 Call(release_func, {GetExecutableRunOptionsArgument(), b_.getInt32(length_32),
469 acquired_pointer, shape_ptr, b_.getInt32(shape_length)});
470
471 return Status::OK();
472 }
473
Status IrEmitter::HandleOutfeed(HloInstruction* outfeed) {
  // Outfeed produces no useful result, but it does return a token[] that can
  // be threaded through to other side-effecting operations to ensure ordering.
  // In the IR emitter we treat this token as a normal u8[] and thus need to
  // insert an entry for it in emitted_value_.
479 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(outfeed));
480
481 HloInstruction* operand = outfeed->operands()[0];
482 const Shape& operand_shape = operand->shape();
483
484 llvm::Value* value = GetEmittedValueFor(operand);
485 if (!operand_shape.IsTuple()) {
486 return EmitXfeedTransfer(XfeedKind::kOutfeed, operand_shape, value);
487 }
488
489 TF_RET_CHECK(!ShapeUtil::IsNestedTuple(operand_shape));
490
491 for (int64 i = 0; i < operand_shape.tuple_shapes_size(); ++i) {
492 const Shape& tuple_element_shape =
493 ShapeUtil::GetTupleElementShape(operand_shape, i);
494 llvm::Value* tuple_element = llvm_ir::EmitGetTupleElement(
495 tuple_element_shape, i, MinimumAlignmentForShape(tuple_element_shape),
496 value, &b_);
497 TF_RETURN_IF_ERROR(EmitXfeedTransfer(XfeedKind::kOutfeed,
498 tuple_element_shape, tuple_element));
499 }
500
501 return Status::OK();
502 }
503
Status IrEmitter::HandleSort(HloInstruction* hlo) {
505 const HloSortInstruction* sort = Cast<HloSortInstruction>(hlo);
506 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(sort));
507 Shape keys_shape = sort->keys()->shape();
508 PrimitiveType keys_type = keys_shape.element_type();
509 switch (keys_type) {
510 case PRED:
511 case S8:
512 case U8:
513 case S16:
514 case U16:
515 case BF16:
516 case F16:
517 case S32:
518 case U32:
519 case F32:
520 case S64:
521 case U64:
522 case F64:
523 break;
524 default:
525 return Unimplemented(
526 "Element type %s not supported in the Sort op on CPU.",
527 PrimitiveType_Name(keys_type));
528 }
529 std::vector<llvm::Value*> destination_addresses(sort->operand_count());
530 for (int64 i = 0; i < sort->operand_count(); ++i) {
531 ShapeIndex shape_index =
532 sort->values_count() > 0 ? ShapeIndex({i}) : ShapeIndex({});
533 const HloInstruction* operand = sort->operand(i);
534 // We assume that the layout of all involved operands and outputs is the
535 // same.
536 TF_RET_CHECK(
537 LayoutUtil::LayoutsInShapesEqual(keys_shape, operand->shape()));
538 TF_RET_CHECK(LayoutUtil::LayoutsInShapesEqual(
539 keys_shape, ShapeUtil::GetSubshape(sort->shape(), shape_index)));
540
541 // The sort is implemented in-place, therefore we first copy the operand
542 // buffer to the output buffer if they are not the same.
543 auto destination_buffer = GetAllocationSlice(*sort, shape_index);
544 destination_addresses[i] =
545 EmitBufferPointer(destination_buffer, operand->shape());
546 auto source_address = GetAllocationSlice(*operand);
547 if (destination_buffer != source_address) {
548 int64 primitive_type_size =
549 ShapeUtil::ByteSizeOfPrimitiveType(operand->shape().element_type());
550 auto source_buffer = GetEmittedValueFor(operand);
551 int64 size = ByteSizeOf(operand->shape());
552 MemCpy(destination_addresses[i], /*DstAlign=*/primitive_type_size,
553 source_buffer,
554 /*SrcAlign=*/primitive_type_size, size);
555 }
556 }
557
558 // Normalize the shape and the dimension to sort.
559 Shape normalized_keys_shape =
560 ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(keys_shape);
561 int64 physical_dimension_to_sort = LayoutUtil::MakeLogicalToPhysical(
562 keys_shape.layout())[sort->sort_dimension()];
563
564 int64 sort_dimension_elements =
565 normalized_keys_shape.dimensions(physical_dimension_to_sort);
566 int64 higher_dimensions = 1;
567 for (int64 i = 0; i < physical_dimension_to_sort; ++i) {
568 higher_dimensions *= normalized_keys_shape.dimensions(i);
569 }
570 int64 lower_dimensions = 1;
571 for (int64 i = normalized_keys_shape.rank() - 1;
572 i > physical_dimension_to_sort; --i) {
573 lower_dimensions *= normalized_keys_shape.dimensions(i);
574 }
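  // The runtime sort routine views each operand as a [higher_dimensions,
  // sort_dimension_elements, lower_dimensions] array in physical layout and
  // sorts along the middle dimension.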
575
576 auto less_than_function = FindOrDie(emitted_functions_, sort->to_apply());
577 CHECK(absl::c_binary_search(thread_local_computations_, sort->to_apply()));
578 llvm::FunctionType* key_value_sort_type = llvm::FunctionType::get(
579 b_.getVoidTy(),
580 {b_.getInt64Ty(), b_.getInt64Ty(), b_.getInt64Ty(),
581 b_.getInt8PtrTy()->getPointerTo(), b_.getInt32Ty(),
582 b_.getInt32Ty()->getPointerTo(), b_.getInt1Ty(), b_.getInt8PtrTy(),
583 b_.getInt64Ty()->getPointerTo(), less_than_function->getType()},
584 /*isVarArg=*/false);
585 auto* key_value_sort_func = llvm::dyn_cast<llvm::Function>(
586 module_
587 ->getOrInsertFunction(runtime::kKeyValueSortSymbolName,
588 key_value_sort_type)
589 .getCallee());
590 key_value_sort_func->setCallingConv(llvm::CallingConv::C);
591 key_value_sort_func->setDoesNotThrow();
592 llvm::Value* values = llvm_ir::EmitAllocaAtFunctionEntryWithCount(
593 b_.getInt8PtrTy(), b_.getInt32(sort->operand_count()), "cc_values_alloca",
594 &b_);
595 llvm::Value* sizes = llvm_ir::EmitAllocaAtFunctionEntryWithCount(
596 b_.getInt32Ty(), b_.getInt32(sort->operand_count()), "cc_sizes_alloca",
597 &b_);
598 for (int64 i = 0; i < sort->operand_count(); ++i) {
599 llvm::Value* value_as_i8ptr =
600 PointerCast(destination_addresses[i], b_.getInt8PtrTy());
601 llvm::Value* slot_in_values_alloca =
602 ConstInBoundsGEP1_32(b_.getInt8PtrTy(), values, i);
603 Store(value_as_i8ptr, slot_in_values_alloca);
604 llvm::Value* slot_in_sizes_alloca =
605 ConstInBoundsGEP1_32(b_.getInt32Ty(), sizes, i);
606 llvm::Value* size = b_.getInt32(ShapeUtil::ByteSizeOfPrimitiveType(
607 sort->operand(i)->shape().element_type()));
608 Store(size, slot_in_sizes_alloca);
609 }
610
611 Call(key_value_sort_func,
612 {b_.getInt64(higher_dimensions), b_.getInt64(sort_dimension_elements),
613 b_.getInt64(lower_dimensions), values,
614 b_.getInt32(sort->operand_count()), sizes,
615 b_.getInt1(sort->is_stable()), GetExecutableRunOptionsArgument(),
616 GetProfileCountersArgument(), less_than_function});
617
618 if (sort->values_count() > 0) {
619 llvm_ir::EmitTuple(GetIrArrayFor(sort), destination_addresses, &b_);
620 }
621 return Status::OK();
622 }
623
Status IrEmitter::HandleTuple(HloInstruction* tuple) {
625 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(tuple));
626 std::vector<llvm::Value*> base_ptrs;
627 for (auto operand : tuple->operands()) {
628 base_ptrs.push_back(GetEmittedValueFor(operand));
629 }
630 llvm_ir::EmitTuple(GetIrArrayFor(tuple), base_ptrs, &b_);
631 return Status::OK();
632 }
633
llvm::Value* IrEmitter::EmitElementalMap(
    const HloMapInstruction& map_instr,
    absl::Span<llvm::Value* const> elemental_operands, absl::string_view name) {
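  // The mapped computation is invoked as a thread-local call on the
  // already-computed operand elements, yielding a single output element.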
637 return EmitThreadLocalCall(*map_instr.to_apply(), elemental_operands, name);
638 }
639
StatusOr<llvm::Value*> IrEmitter::EmitElementalReduceWindow(
    const HloReduceWindowInstruction* reduce_window,
    const llvm_ir::ElementGenerator& input_generator,
    const llvm_ir::IrArray::Index& index) {
644 const HloInstruction* operand = reduce_window->operand(0);
645 const Window& window = reduce_window->window();
646
647 // We fold inputs into the accumulator and initialize it to
648 // the initial value on the reduce_window.
649 PrimitiveType operand_element_type = operand->shape().element_type();
650 llvm::Value* accumulator_address = llvm_ir::EmitAllocaAtFunctionEntry(
651 llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_),
652 "reduce_window_accumulator_address", &b_,
653 MinimumAlignmentForPrimitiveType(operand_element_type));
654 Store(Load(GetEmittedValueFor(reduce_window->operand(1))),
655 accumulator_address);
656
657 llvm_ir::ForLoopNest loops(IrName(reduce_window, "inner"), &b_);
658 std::vector<int64> window_size;
659 for (const auto& dim : window.dimensions()) {
660 window_size.push_back(dim.size());
661 }
662 const llvm_ir::IrArray::Index window_index = loops.AddLoopsForShape(
663 ShapeUtil::MakeShape(operand_element_type, window_size), "window");
664 CHECK_EQ(window_index.size(), index.size());
665
666 SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_);
667
668 std::vector<llvm::Value*> input_multi_index(index.size());
669 llvm::Value* in_bounds_condition = nullptr;
670 for (size_t i = 0; i < index.size(); ++i) {
671 llvm::Value* strided_index =
672 NSWMul(index[i], b_.getInt64(window.dimensions(i).stride()));
673 input_multi_index[i] = NSWSub(
674 NSWAdd(strided_index,
675 NSWMul(window_index[i],
676 b_.getInt64(window.dimensions(i).window_dilation()))),
677 b_.getInt64(window.dimensions(i).padding_low()));
678
679 // We need to verify that we are not in the dilated base area.
680 llvm::Value* dilation_condition =
681 ICmpEQ(SRem(input_multi_index[i],
682 b_.getInt64(window.dimensions(i).base_dilation())),
683 b_.getInt64(0));
684 if (in_bounds_condition == nullptr) {
685 in_bounds_condition = dilation_condition;
686 } else {
687 in_bounds_condition = And(in_bounds_condition, dilation_condition);
688 }
689
690 // Apply base dilation to the index.
691 input_multi_index[i] =
692 SDiv(input_multi_index[i],
693 b_.getInt64(window.dimensions(i).base_dilation()));
694
695 // We need to check if 0 <= input_multi_index[i] < bound, as otherwise we
696 // are in the padding so that we can skip the computation. That is
697 // equivalent to input_multi_index[i] < bound as an *unsigned* comparison,
698 // since a negative value will wrap to a large positive value.
699 llvm::Value* index_condition =
700 ICmpULT(input_multi_index[i],
701 b_.getInt64(ShapeUtil::GetDimension(operand->shape(), i)));
702 if (in_bounds_condition == nullptr) {
703 in_bounds_condition = index_condition;
704 } else {
705 in_bounds_condition = And(in_bounds_condition, index_condition);
706 }
707 }
708 CHECK(in_bounds_condition != nullptr);
709
710 llvm_ir::LlvmIfData if_data =
711 llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &b_);
712 SetToFirstInsertPoint(if_data.true_block, &b_);
713
714 // We are not in the padding, so carry out the computation.
715 llvm_ir::IrArray::Index input_index(input_multi_index, operand->shape(),
716 b_.getInt64Ty());
717 TF_ASSIGN_OR_RETURN(llvm::Value* const input_value,
718 input_generator(input_index));
719 llvm::Value* result = EmitThreadLocalCall(
720 *reduce_window->to_apply(), {Load(accumulator_address), input_value},
721 "reducer_function");
722 Store(result, accumulator_address);
723
724 SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_);
725 return Load(accumulator_address);
726 }
727
Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) {
729 // Pseudo code for reduce window:
730 //
731 // for (coordinates O in the output)
732 // value = init_value;
733 // for (coordinates W in the window)
734 // for each index i:
735 // input coordinates I_i = O_i * stride_i + W_i - pad_low_i
736 // if I within bounds of input:
737 // value = function(value, input(I));
738 // output(O) = value;
739 //
740 // This is completely un-optimized and just here to have something
741 // that works.
742 return DefaultAction(reduce_window);
743 }
744
Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) {
746 CHECK_EQ(select_and_scatter->operand_count(), 3);
747 const auto operand = select_and_scatter->operand(0);
748 const auto source = select_and_scatter->operand(1);
749 const auto init_value = select_and_scatter->operand(2);
750 const Window& window = select_and_scatter->window();
751 PrimitiveType operand_element_type = operand->shape().element_type();
752 const int64 rank = operand->shape().rank();
753 CHECK_EQ(rank, source->shape().rank());
754 CHECK_EQ(rank, window.dimensions_size());
755
756 // TODO(b/31410564): Implement dilation for select-and-scatter.
757 if (window_util::HasDilation(window)) {
    return Unimplemented(
        "Dilation for SelectAndScatter is not implemented on CPU.");
760 }
761
762 // Pseudo code for select-and-scatter:
763 //
764 // initialized_flag is initially off for every window, and is turned on after
765 // the first iteration is completed and the first operand value is selected.
766 //
767 // output(*) = init_value
768 // for (coordinates S in the source) {
769 // initialized_flag = false
770 // for (coordinates W in the window) {
771 // I = S * stride + W - pad_low
772 // if I within bounds of operand:
773 // if !initialized_flag or select(selected_value, operand(I)) == false:
774 // selected_value = operand(I)
775 // selected_index = I
776 // initialized_flag = true
777 // }
778 // output(selected_index) = scatter(output(selected_index), source(S))
779 // }
780 //
781
782 // Initialize the output array with the given init_value.
783 TF_RETURN_IF_ERROR(EmitTargetElementLoop(
784 select_and_scatter, /*desc=*/IrName(select_and_scatter, "init"),
785 [this, init_value](const llvm_ir::IrArray::Index& target_index) {
786 llvm::Value* init_value_addr = GetEmittedValueFor(init_value);
787 return Load(init_value_addr);
788 }));
789
790 // Create a loop to iterate over the source array to scatter to the output.
791 llvm_ir::ForLoopNest source_loops(IrName(select_and_scatter), &b_);
792 const llvm_ir::IrArray::Index source_index =
793 source_loops.AddLoopsForShape(source->shape(), "source");
794 SetToFirstInsertPoint(source_loops.GetInnerLoopBodyBasicBlock(), &b_);
795
796 // Allocate space to keep the currently selected value, its index, and
797 // the boolean initialized_flag, which is initially set to false.
798 llvm::Value* selected_value_address = llvm_ir::EmitAllocaAtFunctionEntry(
799 llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_),
800 "selected_value_address", &b_,
801 MinimumAlignmentForPrimitiveType(operand_element_type));
802 llvm::Value* selected_index_address =
803 llvm_ir::EmitAllocaAtFunctionEntryWithCount(
804 b_.getInt64Ty(), b_.getInt32(rank), "selected_index_address", &b_);
805 llvm::Value* initialized_flag_address = llvm_ir::EmitAllocaAtFunctionEntry(
806 b_.getInt1Ty(), "initialized_flag_address", &b_);
807 Store(b_.getInt1(false), initialized_flag_address);
808
809 // Create the inner loop to iterate over the window.
810 llvm_ir::ForLoopNest window_loops(IrName(select_and_scatter, "window"), &b_);
811 std::vector<int64> window_size;
812 for (const auto& dim : window.dimensions()) {
813 window_size.push_back(dim.size());
814 }
815 const llvm_ir::IrArray::Index window_index = window_loops.AddLoopsForShape(
816 ShapeUtil::MakeShape(operand_element_type, window_size), "window");
817 SetToFirstInsertPoint(window_loops.GetInnerLoopBodyBasicBlock(), &b_);
818
819 // Compute the operand index to visit and evaluate the condition whether the
820 // operand index is within the bounds. The unsigned comparison includes
821 // checking whether the operand index >= 0.
822 std::vector<llvm::Value*> operand_multi_index(source_index.size());
823 llvm::Value* in_bounds_condition = b_.getTrue();
824 for (int64 i = 0; i < rank; ++i) {
825 llvm::Value* strided_index =
826 NSWMul(source_index[i], b_.getInt64(window.dimensions(i).stride()));
827 operand_multi_index[i] =
828 NSWSub(NSWAdd(strided_index, window_index[i]),
829 b_.getInt64(window.dimensions(i).padding_low()));
830 llvm::Value* index_condition =
831 ICmpULT(operand_multi_index[i],
832 b_.getInt64(ShapeUtil::GetDimension(operand->shape(), i)));
833 in_bounds_condition = And(in_bounds_condition, index_condition);
834 }
835 CHECK(in_bounds_condition != nullptr);
836
837 // Only need to do something if the operand index is within the bounds. First
838 // check if the initialized_flag is set.
839 llvm_ir::LlvmIfData if_in_bounds =
840 llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &b_);
841 SetToFirstInsertPoint(if_in_bounds.true_block, &b_);
842 llvm_ir::LlvmIfData if_initialized = llvm_ir::EmitIfThenElse(
843 Load(initialized_flag_address), "initialized", &b_);
844
845 // If the initialized_flag is false, initialize the selected value and index
846 // with the currently visiting operand.
847 SetToFirstInsertPoint(if_initialized.false_block, &b_);
848 const auto save_operand_index =
849 [&](const llvm_ir::IrArray::Index& operand_index) {
850 for (int64 i = 0; i < rank; ++i) {
851 llvm::Value* selected_index_address_slot =
852 InBoundsGEP(selected_index_address, {b_.getInt32(i)});
853 Store(operand_index[i], selected_index_address_slot);
854 }
855 };
856 llvm_ir::IrArray operand_array(GetIrArrayFor(operand));
857 llvm_ir::IrArray::Index operand_index(
858 operand_multi_index, operand_array.GetShape(), b_.getInt64Ty());
859 llvm::Value* operand_data =
860 operand_array.EmitReadArrayElement(operand_index, &b_);
861 Store(operand_data, selected_value_address);
862 save_operand_index(operand_index);
863 Store(b_.getInt1(true), initialized_flag_address);
864
865 // If the initialized_flag is true, call the `select` function to potentially
866 // update the selected value and index with the currently visiting operand.
867 SetToFirstInsertPoint(if_initialized.true_block, &b_);
868 llvm::Value* operand_address =
869 operand_array.EmitArrayElementAddress(operand_index, &b_);
870 llvm::Value* operand_element = Load(operand_address);
871 llvm::Value* result = EmitThreadLocalCall(
872 *select_and_scatter->select(),
873 {Load(selected_value_address), operand_element}, "select_function");
874
875 // If the 'select' function returns false, update the selected value and the
876 // index to the currently visiting operand.
877 llvm::Value* cond = ICmpNE(
878 result,
879 llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0),
880 "boolean_predicate");
881 llvm_ir::LlvmIfData if_select_lhs =
882 llvm_ir::EmitIfThenElse(cond, "if-select-lhs", &b_);
883 SetToFirstInsertPoint(if_select_lhs.false_block, &b_);
884 Store(Load(operand_address), selected_value_address);
885 save_operand_index(operand_index);
886
887 // After iterating over the window elements, scatter the source element to
888 // the selected index of the output. The value we store at the output
889 // location is computed by calling the `scatter` function with the source
890 // value and the current output value.
891 SetToFirstInsertPoint(window_loops.GetOuterLoopExitBasicBlock(), &b_);
892 std::vector<llvm::Value*> selected_multi_index;
893 for (int64 i = 0; i < rank; ++i) {
894 llvm::Value* selected_index_address_slot =
895 InBoundsGEP(selected_index_address, {b_.getInt32(i)});
896 selected_multi_index.push_back(Load(selected_index_address_slot));
897 }
898 llvm_ir::IrArray source_array(GetIrArrayFor(source));
899 llvm::Value* source_value =
900 source_array.EmitReadArrayElement(source_index, &b_);
901 llvm_ir::IrArray output_array(GetIrArrayFor(select_and_scatter));
902 llvm_ir::IrArray::Index selected_index(
903 selected_multi_index, output_array.GetShape(), source_index.GetType());
904 llvm::Value* output_value =
905 output_array.EmitReadArrayElement(selected_index, &b_);
906 llvm::Value* scatter_value =
907 EmitThreadLocalCall(*select_and_scatter->scatter(),
908 {output_value, source_value}, "scatter_function");
909 output_array.EmitWriteArrayElement(selected_index, scatter_value, &b_);
910
911 SetToFirstInsertPoint(source_loops.GetOuterLoopExitBasicBlock(), &b_);
912 return Status::OK();
913 }
914
Status IrEmitter::HandleDot(HloInstruction* dot) {
916 auto lhs = dot->operand(0);
917 auto rhs = dot->operand(1);
918 TF_RETURN_IF_ERROR(ElementTypesSameAndSupported(
919 /*instruction=*/*dot, /*operands=*/{lhs, rhs},
920 /*supported_types=*/{F16, F32, F64, C64, C128}));
921 const DotDimensionNumbers& dnums = dot->dot_dimension_numbers();
922
923 if (dnums.lhs_contracting_dimensions_size() != 1) {
924 // This is disallowed by ShapeInference today.
925 return Unimplemented(
926 "Dot with multiple contracting dimensions not implemented.");
927 }
928
929 llvm_ir::IrArray lhs_array(GetIrArrayFor(lhs));
930 llvm_ir::IrArray rhs_array(GetIrArrayFor(rhs));
931
932 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(dot));
933 llvm_ir::IrArray target_array = GetIrArrayFor(dot);
934
935 VLOG(2) << "HandleDot: ";
936 VLOG(2) << " lhs operand: "
937 << llvm_ir::DumpToString(*lhs_array.GetBasePointer());
938 VLOG(2) << " rhs operand: "
939 << llvm_ir::DumpToString(*rhs_array.GetBasePointer());
940 VLOG(2) << " target: "
941 << llvm_ir::DumpToString(*target_array.GetBasePointer());
942
943 // Dot operation is complicated so we delegate to a helper class.
944 return EmitDotOperation(*dot, target_array, lhs_array, rhs_array,
945 /*addend_array=*/nullptr,
946 GetExecutableRunOptionsArgument(), &b_,
947 hlo_module_config_, target_machine_features_);
948 }
949
StatusOr<llvm::Value*> IrEmitter::EmitElementalConvolution(
    const HloConvolutionInstruction* convolution,
    const llvm_ir::ElementGenerator& input_generator,
    const llvm_ir::ElementGenerator& kernel_generator,
    const llvm_ir::IrArray::Index& index) {
955 const HloInstruction* lhs = convolution->operand(0);
956 const HloInstruction* rhs = convolution->operand(1);
957 const Window& window = convolution->window();
958
959 const ConvolutionDimensionNumbers& dnums =
960 convolution->convolution_dimension_numbers();
961 int num_spatial_dims = dnums.output_spatial_dimensions_size();
962 std::vector<llvm::Value*> output_spatial(num_spatial_dims);
963 for (int i = 0; i < num_spatial_dims; ++i) {
964 output_spatial[i] = index[dnums.output_spatial_dimensions(i)];
965 }
966 llvm::Value* output_feature = index[dnums.output_feature_dimension()];
967 llvm::Value* batch = index[dnums.output_batch_dimension()];
968
969 // We will accumulate the products into this sum to calculate the output entry
970 // at the given index.
971 PrimitiveType lhs_element_type = lhs->shape().element_type();
972 llvm::Type* lhs_llvm_type =
973 llvm_ir::PrimitiveTypeToIrType(lhs_element_type, module_);
974 llvm::Value* sum_address = llvm_ir::EmitAllocaAtFunctionEntry(
975 lhs_llvm_type, "convolution_sum_address", &b_,
976 MinimumAlignmentForPrimitiveType(lhs_element_type));
977 llvm::Value* constant_zero = llvm::Constant::getNullValue(lhs_llvm_type);
978 Store(constant_zero, sum_address);
979
980 llvm_ir::ForLoopNest loops(IrName(convolution, "inner"), &b_);
981 std::vector<llvm::Value*> kernel_spatial(num_spatial_dims);
982 for (int i = 0; i < num_spatial_dims; ++i) {
983 kernel_spatial[i] =
984 loops
985 .AddLoop(
986 0, rhs->shape().dimensions(dnums.kernel_spatial_dimensions(i)),
987 absl::StrCat("k", i))
988 ->GetIndVarValue();
989 }
990 llvm::Value* input_feature =
991 loops
992 .AddLoop(0, lhs->shape().dimensions(dnums.input_feature_dimension()),
993 "iz")
994 ->GetIndVarValue();
995
996 SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_);
997
998 // Calculate the spatial index in the input array, taking striding, dilation
999 // and padding into account. An index in the padding will be out of the bounds
1000 // of the array.
1001 const auto calculate_input_index = [this](llvm::Value* output_index,
1002 llvm::Value* kernel_index,
1003 const WindowDimension& window_dim) {
1004 llvm::Value* strided_index =
1005 NSWMul(output_index, b_.getInt64(window_dim.stride()));
1006 llvm::Value* dilated_kernel_index =
1007 NSWMul(kernel_index, b_.getInt64(window_dim.window_dilation()));
1008 return NSWSub(NSWAdd(strided_index, dilated_kernel_index),
1009 b_.getInt64(window_dim.padding_low()));
1010 };
1011 std::vector<llvm::Value*> input_spatial(num_spatial_dims);
1012 for (int i = 0; i < num_spatial_dims; ++i) {
1013 input_spatial[i] = calculate_input_index(
1014 output_spatial[i], kernel_spatial[i], window.dimensions(i));
1015 }
1016
1017 // We need to check if 0 <= input dim < bound, as otherwise we are in the
1018 // padding so that we can skip the computation. That is equivalent to input
1019 // dim < bound as an *unsigned* comparison, since a negative value will wrap
1020 // to a large positive value. The input dim is dilated, so we need to dilate
1021 // the bound as well to match.
1022
1023 // Also need to check that the input coordinates are not in one of the
1024 // holes created by base dilation.
1025 const auto not_in_hole = [&](llvm::Value* input_index, int64 base_dilation) {
1026 llvm::Value* remainder = SRem(input_index, b_.getInt64(base_dilation));
1027 return ICmpEQ(remainder, b_.getInt64(0));
1028 };
1029
1030 llvm::Value* in_bounds_condition = b_.getInt1(true);
1031 for (int i = 0; i < num_spatial_dims; ++i) {
1032 llvm::ConstantInt* input_bound = b_.getInt64(window_util::DilatedBound(
1033 lhs->shape().dimensions(dnums.input_spatial_dimensions(i)),
1034 window.dimensions(i).base_dilation()));
1035 llvm::Value* dim_in_bound = ICmpULT(input_spatial[i], input_bound);
1036 llvm::Value* dim_not_in_hole =
1037 not_in_hole(input_spatial[i], window.dimensions(i).base_dilation());
1038 llvm::Value* dim_ok = And(dim_in_bound, dim_not_in_hole);
1039 in_bounds_condition = And(in_bounds_condition, dim_ok);
1040 }
1041
1042 // Now we need to map the dilated base coordinates back to the actual
1043 // data indices on the lhs.
1044 const auto undilate = [&](llvm::Value* input_index, int64 base_dilation) {
1045 return SDiv(input_index, b_.getInt64(base_dilation));
1046 };
1047 for (int i = 0; i < num_spatial_dims; ++i) {
1048 input_spatial[i] =
1049 undilate(input_spatial[i], window.dimensions(i).base_dilation());
1050 }
1051
1052 llvm_ir::LlvmIfData if_data =
1053 llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &b_);
1054 SetToFirstInsertPoint(if_data.true_block, &b_);
1055
1056 // We are not in the padding, so carry out the computation.
1057 int num_dims = num_spatial_dims + 2;
1058 std::vector<llvm::Value*> input_multi_index(num_dims);
1059 for (int i = 0; i < num_spatial_dims; ++i) {
1060 input_multi_index[dnums.input_spatial_dimensions(i)] = input_spatial[i];
1061 }
1062 input_multi_index[dnums.input_feature_dimension()] = input_feature;
1063 input_multi_index[dnums.input_batch_dimension()] = batch;
1064
1065 std::vector<llvm::Value*> kernel_multi_index(num_dims);
1066 for (int i = 0; i < num_spatial_dims; ++i) {
1067 kernel_multi_index[dnums.kernel_spatial_dimensions(i)] =
1068 window.dimensions(i).window_reversal()
1069 ? NSWSub(b_.getInt64(window.dimensions(i).size() - 1),
1070 kernel_spatial[i])
1071 : kernel_spatial[i];
1072 }
1073
1074 kernel_multi_index[dnums.kernel_input_feature_dimension()] = input_feature;
1075 kernel_multi_index[dnums.kernel_output_feature_dimension()] = output_feature;
1076
1077 llvm_ir::IrArray::Index input_index(input_multi_index, lhs->shape(),
1078 b_.getInt64Ty());
1079 TF_ASSIGN_OR_RETURN(llvm::Value* const input_value,
1080 input_generator(input_index));
1081 llvm_ir::IrArray::Index kernel_index(kernel_multi_index, rhs->shape(),
1082 b_.getInt64Ty());
1083 TF_ASSIGN_OR_RETURN(llvm::Value* const kernel_value,
1084 kernel_generator(kernel_index));
1085 llvm::Value* product = FMul(input_value, kernel_value);
1086 llvm::Value* sum = FAdd(Load(sum_address), product);
1087 Store(sum, sum_address);
1088
1089 SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_);
1090 return Load(sum_address);
1091 }
1092
Status IrEmitter::HandleConvolution(HloInstruction* convolution) {
1094 auto lhs = convolution->operand(0);
1095 auto rhs = convolution->operand(1);
1096 TF_RETURN_IF_ERROR(ElementTypesSameAndSupported(
1097 /*instruction=*/*convolution, /*operands=*/{lhs, rhs},
1098 /*supported_types=*/{F16, F32, F64, C64, C128}));
1099
  // TODO(tonywy): Add PotentiallyImplementedAsMKLConvolution to support
  // different data layouts.
1102 if (PotentiallyImplementedAsEigenConvolution(*convolution,
1103 target_machine_features_)) {
1104 const Shape& lhs_shape = lhs->shape();
1105 const Shape& rhs_shape = rhs->shape();
1106 const Shape& convolution_shape = convolution->shape();
1107 // The input, kernel and output agree with respect to layout.
1108 if (LayoutUtil::IsMonotonicWithDim0Major(lhs_shape.layout()) &&
1109 LayoutUtil::IsMonotonicWithDim0Major(rhs_shape.layout()) &&
1110 LayoutUtil::IsMonotonicWithDim0Major(convolution_shape.layout())) {
1111 // We lower 1D convolutions into calls to the same Eigen function as 2D
1112 // convolutions, except that we pretend that the 1D convolution is really
1113 // a 2D convolution with the missing dimension set to 1. We also adjust
1114 // the padding, dilation parameters as needed.
1115 bool one_dim_convolution = lhs_shape.dimensions_size() == 3;
1116 llvm::Value* lhs_address = GetEmittedValueFor(lhs);
1117 llvm::Value* rhs_address = GetEmittedValueFor(rhs);
1118 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(convolution));
1119
1120 const ConvolutionDimensionNumbers& dnums =
1121 convolution->convolution_dimension_numbers();
1122
1123 // Input tensor.
1124 const Shape& input_shape = convolution->operand(0)->shape();
1125 int64 input_batch = input_shape.dimensions(dnums.input_batch_dimension());
1126 int64 input_rows =
1127 input_shape.dimensions(dnums.input_spatial_dimensions(0));
1128 int64 input_cols =
1129 one_dim_convolution
1130 ? 1
1131 : input_shape.dimensions(dnums.input_spatial_dimensions(1));
1132 int64 input_channels =
1133 input_shape.dimensions(dnums.input_feature_dimension());
1134
1135 // Kernel tensor.
1136 const Shape& kernel_shape = convolution->operand(1)->shape();
1137 int64 kernel_rows =
1138 kernel_shape.dimensions(dnums.kernel_spatial_dimensions(0));
1139 int64 kernel_cols =
1140 one_dim_convolution
1141 ? 1
1142 : kernel_shape.dimensions(dnums.kernel_spatial_dimensions(1));
1143 int64 kernel_channels =
1144 kernel_shape.dimensions(dnums.kernel_input_feature_dimension());
1145 int64 kernel_filters =
1146 kernel_shape.dimensions(dnums.kernel_output_feature_dimension());
1147
1148 // Output tensor.
1149 const Shape& convolution_shape = convolution->shape();
1150 int64 output_rows =
1151 convolution_shape.dimensions(dnums.output_spatial_dimensions(0));
1152 int64 output_cols = one_dim_convolution
1153 ? 1
1154 : convolution_shape.dimensions(
1155 dnums.output_spatial_dimensions(1));
1156
1157 // Extract the window stride for the convolution.
1158 const Window& window = convolution->window();
1159 int64 row_stride = window.dimensions(0).stride();
1160 int64 col_stride =
1161 one_dim_convolution ? 1 : window.dimensions(1).stride();
1162
1163 int64 padding_top = window.dimensions(0).padding_low();
1164 int64 padding_bottom = window.dimensions(0).padding_high();
1165 int64 padding_left =
1166 one_dim_convolution ? 0 : window.dimensions(1).padding_low();
1167 int64 padding_right =
1168 one_dim_convolution ? 0 : window.dimensions(1).padding_high();
1169
1170 int64 lhs_row_dilation = window.dimensions(0).base_dilation();
1171 int64 lhs_col_dilation =
1172 one_dim_convolution ? 1 : window.dimensions(1).base_dilation();
1173 int64 rhs_row_dilation = window.dimensions(0).window_dilation();
1174 int64 rhs_col_dilation =
1175 one_dim_convolution ? 1 : window.dimensions(1).window_dilation();
1176
1177 PrimitiveType primitive_type = lhs->shape().element_type();
1178 llvm::Type* ir_ptr_type = primitive_type == F16
1179 ? b_.getHalfTy()->getPointerTo()
1180 : b_.getFloatTy()->getPointerTo();
1181 llvm::Type* int64_type = b_.getInt64Ty();
1182 llvm::Type* int8_ptr_type = b_.getInt8Ty()->getPointerTo();
1183 llvm::FunctionType* conv_type = llvm::FunctionType::get(
1184 b_.getVoidTy(),
1185 {int8_ptr_type, ir_ptr_type, ir_ptr_type, ir_ptr_type, int64_type,
1186 int64_type, int64_type, int64_type, int64_type, int64_type,
1187 int64_type, int64_type, int64_type, int64_type, int64_type,
1188 int64_type, int64_type, int64_type, int64_type, int64_type,
1189 int64_type, int64_type, int64_type, int64_type},
1190 /*isVarArg=*/false);
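      // This function type must match the signature of the conv2d runtime
      // entry point that is resolved by symbol name below.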
1191 bool multi_threaded =
1192 hlo_module_config_.debug_options().xla_cpu_multi_thread_eigen();
1193 bool use_mkl_dnn =
1194 hlo_module_config_.debug_options().xla_cpu_use_mkl_dnn();
1195
      // TODO(b/78639006) Single-threaded MKL conv2d is not implemented
      // because setting omp_num_threads creates a potential race condition.
1198 const char* fn_name =
1199 primitive_type == F16
1200 ? (multi_threaded
1201 ? runtime::kEigenConvF16SymbolName
1202 : runtime::kEigenSingleThreadedConvF16SymbolName)
1203 : (multi_threaded
1204 ? (use_mkl_dnn ? runtime::kMKLConvF32SymbolName
1205 : runtime::kEigenConvF32SymbolName)
1206 : runtime::kEigenSingleThreadedConvF32SymbolName);
1207 if (!multi_threaded && use_mkl_dnn) {
1208 LOG(WARNING) << "Using Eigen instead of MKL-DNN for single-threaded "
1209 "conv2d function.";
1210 }
1211 llvm::Function* conv_func = llvm::dyn_cast<llvm::Function>(
1212 module_->getOrInsertFunction(fn_name, conv_type).getCallee());
1213 conv_func->setCallingConv(llvm::CallingConv::C);
1214 conv_func->setDoesNotThrow();
1215 conv_func->setOnlyAccessesArgMemory();
1216 Call(conv_func, {
1217 GetExecutableRunOptionsArgument(),
1218 BitCast(GetEmittedValueFor(convolution), ir_ptr_type),
1219 BitCast(lhs_address, ir_ptr_type),
1220 BitCast(rhs_address, ir_ptr_type),
1221 b_.getInt64(input_batch),
1222 b_.getInt64(input_rows),
1223 b_.getInt64(input_cols),
1224 b_.getInt64(input_channels),
1225 b_.getInt64(kernel_rows),
1226 b_.getInt64(kernel_cols),
1227 b_.getInt64(kernel_channels),
1228 b_.getInt64(kernel_filters),
1229 b_.getInt64(output_rows),
1230 b_.getInt64(output_cols),
1231 b_.getInt64(row_stride),
1232 b_.getInt64(col_stride),
1233 b_.getInt64(padding_top),
1234 b_.getInt64(padding_bottom),
1235 b_.getInt64(padding_left),
1236 b_.getInt64(padding_right),
1237 b_.getInt64(lhs_row_dilation),
1238 b_.getInt64(lhs_col_dilation),
1239 b_.getInt64(rhs_row_dilation),
1240 b_.getInt64(rhs_col_dilation),
1241 });
1242
1243 return Status::OK();
1244 }
1245 }
1246
  // This is a completely un-optimized version of convolution, just here to
  // have an early version that works. E.g. the input index and padding
  // calculations are not hoisted out of the inner loop.
1250 //
1251 // See the description of convolution in the XLA documentation for the pseudo
1252 // code for convolution.
1253 return DefaultAction(convolution);
1254 }
1255
Status IrEmitter::HandleFft(HloInstruction* fft) {
1257 auto operand = fft->operand(0);
1258 TF_RETURN_IF_ERROR(ElementTypesSameAndSupported(
1259 /*instruction=*/*fft, /*operands=*/{operand},
1260 /*supported_types=*/{F32, C64}));
1261 TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(operand->shape().layout()));
1262 TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(fft->shape().layout()));
1263 VLOG(3) << "operand=" << ShapeUtil::HumanStringWithLayout(operand->shape());
1264 VLOG(3) << "fft=" << ShapeUtil::HumanStringWithLayout(fft->shape());
1265
1266 llvm::Value* operand_address = GetEmittedValueFor(operand);
1267 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(fft));
1268
1269 const std::vector<int64>& fft_length = fft->fft_length();
1270 int64 input_batch = 1;
1271 for (int i = 0; i < fft->shape().dimensions_size() - fft_length.size(); i++) {
1272 input_batch *= fft->shape().dimensions(i);
1273 }
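  // All leading (non-FFT) dimensions have been collapsed into a single batch
  // count for the runtime call.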
1274
1275 // Args have been computed, make the call.
1276 llvm::Type* int8_ptr_type = b_.getInt8Ty()->getPointerTo();
1277 llvm::Type* int32_type = b_.getInt32Ty();
1278 llvm::Type* int64_type = b_.getInt64Ty();
1279 llvm::FunctionType* fft_type = llvm::FunctionType::get(
1280 b_.getVoidTy(),
1281 {int8_ptr_type, int8_ptr_type, int8_ptr_type, int32_type, int32_type,
1282 int64_type, int64_type, int64_type, int64_type},
1283 /*isVarArg=*/false);
1284
1285 bool multi_threaded_eigen =
1286 hlo_module_config_.debug_options().xla_cpu_multi_thread_eigen();
1287 const char* fn_name = multi_threaded_eigen
1288 ? runtime::kEigenFftSymbolName
1289 : runtime::kEigenSingleThreadedFftSymbolName;
1290
1291 llvm::Function* fft_func = llvm::dyn_cast<llvm::Function>(
1292 module_->getOrInsertFunction(fn_name, fft_type).getCallee());
1293 fft_func->setCallingConv(llvm::CallingConv::C);
1294 fft_func->setDoesNotThrow();
1295 fft_func->setOnlyAccessesInaccessibleMemOrArgMem();
1296 const int fft_rank = fft_length.size();
1297 Call(fft_func,
1298 {GetExecutableRunOptionsArgument(),
1299 BitCast(GetEmittedValueFor(fft), int8_ptr_type),
1300 BitCast(operand_address, int8_ptr_type), b_.getInt32(fft->fft_type()),
1301 b_.getInt32(fft_rank), b_.getInt64(input_batch),
1302 b_.getInt64(fft_rank > 0 ? fft_length[0] : 0),
1303 b_.getInt64(fft_rank > 1 ? fft_length[1] : 0),
1304 b_.getInt64(fft_rank > 2 ? fft_length[2] : 0)});
1305
1306 return Status::OK();
1307 }
1308
1309 Status IrEmitter::HandleAllReduce(HloInstruction* crs) {
1310 if (hlo_module_config_.replica_count() != 1) {
1311 // TODO(b/33011107): Support nontrivial cross replica sum on CPU.
1312 return Unimplemented(
1313 "AllReduce with >1 replica is not implemented on CPU.");
1314 }
1315
1316 // When there is a single replica, a cross replica sum is the identity
1317 // function, and the buffer assignment expects a copy.
1318 //
1319 // TODO(b/80100934): We would like to eliminate one-replica CRS nodes entirely
1320 // in algebraic-simplifier, but currently on some platforms
1321 // HloModuleConfig::num_replicas changes between when the module is compiled
1322 // and when it's run.
1323 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(crs));
1324
1325 // CRS with one operand and one replica is simply the identity function.
1326 if (crs->operand_count() == 1) {
1327 return EmitMemcpy(*crs->operand(0), *crs);
1328 }
1329
1330 // CRS with multiple operands and one replica produces a (one-deep) tuple.
1331 std::vector<llvm::Value*> operand_ptrs;
1332 for (int64 i = 0; i < crs->operand_count(); ++i) {
1333 llvm::Value* in_ptr = GetEmittedValueFor(crs->operand(i));
1334 TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice out_slice,
1335 assignment_.GetUniqueSlice(crs, {i}));
1336
1337 const Shape& operand_shape = crs->operand(i)->shape();
1338 CHECK(operand_shape.IsArray())
1339 << "Operands to all-reduce must be arrays: " << crs->ToString();
1340 operand_ptrs.push_back(EmitBufferPointer(out_slice, operand_shape));
1341
1342 // TODO(b/63762267): Be more aggressive about specifying alignment.
1343 MemCpy(operand_ptrs.back(), /*DstAlign=*/1, in_ptr,
1344 /*SrcAlign=*/1, ShapeUtil::ByteSizeOf(operand_shape));
1345 }
1346 llvm_ir::EmitTuple(GetIrArrayFor(crs), operand_ptrs, &b_);
1347 return Status::OK();
1348 }
1349
1350 Status IrEmitter::HandleParameter(HloInstruction* parameter) {
1351 VLOG(2) << "HandleParameter: " << parameter->ToString();
1352 return EmitTargetAddressForOp(parameter);
1353 }
1354
1355 // Returns true if the relative order of the unreduced dimensions stays the same
1356 // through the reduce operation.
1357 static bool ReductionPreservesLayout(const HloInstruction& reduce) {
1358 DCHECK_EQ(reduce.opcode(), HloOpcode::kReduce);
1359
1360 // Maps dimensions that were not reduced from their dimension numbers in the
1361 // source shape to their dimension numbers in the destination shape.
1362 //
1363 // So if we reduce f32[A,B,C,D] on dimensions 1 and 2, this map contains
1364 // [0->0, 3->1].
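// As an illustrative (assumed) instance: reducing f32[A,B,C,D]{3,2,1,0} over
// dimensions {1,2} yields a result of shape [A,D] with the map shown above.
// Walking the operand layout minor-to-major visits D then A, which map to
// result dimensions 1 and 0; a result layout of {1,0} matches that order, so
// the reduction preserves layout, while a result layout of {0,1} would not.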
1365 absl::flat_hash_map<int64, int64> unreduced_dim_map;
1366
1367 absl::flat_hash_set<int64> reduced_dims(reduce.dimensions().begin(),
1368 reduce.dimensions().end());
1369
1370 const Shape& operand_shape = reduce.operand(0)->shape();
1371 const Shape& result_shape = reduce.shape();
1372
1373 int64 delta = 0;
1374 for (int64 i = 0; i < operand_shape.dimensions_size(); i++) {
1375 if (reduced_dims.contains(i)) {
1376 delta++;
1377 } else {
1378 InsertOrDie(&unreduced_dim_map, i, i - delta);
1379 }
1380 }
1381
1382 // Iterate dimensions minor to major and check that the corresponding
1383 // dimensions in the source and target shapes are equivalent.
1384 int64 result_dim_idx = 0;
1385 for (int64 operand_dim_idx = 0;
1386 operand_dim_idx < operand_shape.dimensions_size(); operand_dim_idx++) {
1387 int64 operand_dim = operand_shape.layout().minor_to_major(operand_dim_idx);
1388 if (!reduced_dims.contains(operand_dim)) {
1389 if (FindOrDie(unreduced_dim_map, operand_dim) !=
1390 result_shape.layout().minor_to_major(result_dim_idx++)) {
1391 return false;
1392 }
1393 }
1394 }
1395
1396 CHECK_EQ(result_dim_idx, result_shape.dimensions_size());
1397
1398 return true;
1399 }
1400
1401 IrEmitter::ReductionGenerator IrEmitter::MatchReductionGenerator(
1402 HloComputation* function, string* failure_reason) const {
1403 CHECK_EQ(function->num_parameters(), 2);
1404
1405 auto root_instruction = function->root_instruction();
1406 CHECK(ShapeUtil::IsScalar(root_instruction->shape()));
1407
1408 if (root_instruction->operand_count() != 2) {
1409 *failure_reason = "root instruction is not a binary operation";
1410 return nullptr;
1411 }
1412
1413 const Shape& root_shape = root_instruction->shape();
1414 if (ShapeUtil::ElementIsComplex(root_shape)) {
1415 // TODO(b/65408531): Complex add could be done via bitcast to <float x [2N]>.
1416 // Complex multiply would be more challenging. We could perhaps use a
1417 // strided load to get all reals in one vector, all imaginaries in another, or use
1418 // CreateShuffleVector on a bitcast to float x [2N].
1419 *failure_reason = "complex values not supported";
1420 return nullptr;
1421 }
1422 bool root_is_floating_point = ShapeUtil::ElementIsFloating(root_shape);
1423 bool root_is_integral = ShapeUtil::ElementIsIntegral(root_shape);
1424 bool root_is_signed = ShapeUtil::ElementIsSigned(root_shape);
1425
1426 auto lhs = root_instruction->operand(0);
1427 auto rhs = root_instruction->operand(1);
1428
1429 auto param_0 = function->parameter_instruction(0);
1430 auto param_1 = function->parameter_instruction(1);
1431 if (!(lhs == param_0 && rhs == param_1) &&
1432 !(rhs == param_0 && lhs == param_1)) {
1433 *failure_reason =
1434 "root instruction is not a binary operation on the incoming arguments";
1435 return nullptr;
1436 }
1437
1438 CHECK(ShapeUtil::IsScalar(lhs->shape()) && ShapeUtil::IsScalar(rhs->shape()));
1439
1440 // This is visually similar to ElementalIrEmitter, though conceptually we're
1441 // doing something different here. ElementalIrEmitter emits scalar operations
1442 // while these emit scalar or vector operations depending on the type of the
1443 // operands. See CreateShardedVectorType for the actual types in use here.
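// As a sketch of what is matched below (an assumed reduction computation, not
// a specific module): a computation whose root is add(param0, param1) over
// F32 scalars maps to the kAdd case and is lowered through CreateFAdd, so the
// generator can later be applied to whole vector shards at once.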
1444 switch (root_instruction->opcode()) {
1445 default:
1446 *failure_reason = "did not recognize root instruction opcode";
1447 return nullptr;
1448
1449 case HloOpcode::kAdd:
1450 return [root_is_integral](llvm::IRBuilder<>* b, llvm::Value* lhs,
1451 llvm::Value* rhs) {
1452 return root_is_integral ? b->CreateAdd(lhs, rhs)
1453 : b->CreateFAdd(lhs, rhs);
1454 };
1455
1456 case HloOpcode::kMultiply:
1457 return [root_is_integral](llvm::IRBuilder<>* b, llvm::Value* lhs,
1458 llvm::Value* rhs) {
1459 return root_is_integral ? b->CreateMul(lhs, rhs)
1460 : b->CreateFMul(lhs, rhs);
1461 };
1462
1463 case HloOpcode::kAnd:
1464 return [](llvm::IRBuilder<>* b, llvm::Value* lhs, llvm::Value* rhs) {
1465 return b->CreateAnd(lhs, rhs);
1466 };
1467
1468 case HloOpcode::kOr:
1469 return [](llvm::IRBuilder<>* b, llvm::Value* lhs, llvm::Value* rhs) {
1470 return b->CreateOr(lhs, rhs);
1471 };
1472
1473 case HloOpcode::kXor:
1474 return [](llvm::IRBuilder<>* b, llvm::Value* lhs, llvm::Value* rhs) {
1475 return b->CreateXor(lhs, rhs);
1476 };
1477
1478 case HloOpcode::kMaximum:
1479 return [root_is_floating_point, root_is_signed](
1480 llvm::IRBuilder<>* b, llvm::Value* lhs,
1481 llvm::Value* rhs) -> llvm::Value* {
1482 if (root_is_floating_point) {
1483 return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::maxnum,
1484 {lhs, rhs}, {lhs->getType()}, b);
1485 }
1486
1487 return b->CreateSelect(
1488 b->CreateICmp(root_is_signed ? llvm::ICmpInst::ICMP_SGE
1489 : llvm::ICmpInst::ICMP_UGE,
1490 lhs, rhs),
1491 lhs, rhs);
1492 };
1493
1494 case HloOpcode::kMinimum:
1495 return [root_is_floating_point, root_is_signed](
1496 llvm::IRBuilder<>* b, llvm::Value* lhs,
1497 llvm::Value* rhs) -> llvm::Value* {
1498 if (root_is_floating_point) {
1499 return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::minnum,
1500 {lhs, rhs}, {lhs->getType()}, b);
1501 }
1502
1503 return b->CreateSelect(
1504 b->CreateICmp(root_is_signed ? llvm::ICmpInst::ICMP_SLE
1505 : llvm::ICmpInst::ICMP_ULE,
1506 lhs, rhs),
1507 lhs, rhs);
1508 };
1509 }
1510 }
1511
1512 IrEmitter::ShardedVectorType IrEmitter::CreateShardedVectorType(
1513 PrimitiveType element_type, unsigned element_count) {
1514 int vector_register_size_in_elements =
1515 target_machine_features_.vector_register_byte_size(
1516 *compute_function_->function()) /
1517 ShapeUtil::ByteSizeOfPrimitiveType(element_type);
1518
1519 ShardedVectorType sharded_vector_type;
1520 llvm::Type* element_ir_type =
1521 llvm_ir::PrimitiveTypeToIrType(element_type, module_);
1522
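// For illustration (assumed register width): with 4 elements per vector
// register and element_count = 11 = 8 + 2 + 1, the loop below produces the
// shards [T, <2 x T>, <4 x T>, <4 x T>] -- a scalar for the 1, a two-element
// vector for the 2, and two full registers for the 8.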
1523 for (int i = 0, e = 1 + tensorflow::Log2Ceiling(element_count); i < e; i++) {
1524 // For every power of two present in element_count, we generate one or more
1525 // vector or scalar types.
1526 const unsigned current_size_fragment = 1u << i;
1527 if (!(element_count & current_size_fragment)) {
1528 // Power of two not present in element_count.
1529 continue;
1530 }
1531
1532 if (current_size_fragment == 1) {
1533 // Single element, use a scalar type.
1534 sharded_vector_type.push_back(element_ir_type);
1535 continue;
1536 }
1537
1538 // Lower "current_size_fragment" elements using as few vector registers as
1539 // possible.
1540
1541 if (current_size_fragment >= vector_register_size_in_elements) {
1542 auto vector_type = llvm::VectorType::get(
1543 element_ir_type, vector_register_size_in_elements);
1544 sharded_vector_type.insert(
1545 sharded_vector_type.end(),
1546 current_size_fragment / vector_register_size_in_elements,
1547 vector_type);
1548
1549 // Both current_size_fragment and vector_register_size_in_elements are
1550 // powers of two.
1551 CHECK_EQ(current_size_fragment % vector_register_size_in_elements, 0);
1552 continue;
1553 }
1554
1555 // For now we assume that vector_register_size_in_elements and lower powers
1556 // of two are all legal vector sizes (or at least can be lowered easily by
1557 // LLVM).
1558 sharded_vector_type.push_back(
1559 llvm::VectorType::get(element_ir_type, current_size_fragment));
1560 }
1561 return sharded_vector_type;
1562 }
1563
1564 StatusOr<IrEmitter::ShardedVector>
1565 IrEmitter::EmitInnerLoopForVectorizedReduction(
1566 const ReductionGenerator& reduction_generator,
1567 const llvm_ir::IrArray::Index& output_index,
1568 const ShardedVectorType& accumulator_type, HloInstruction* init_value,
1569 HloInstruction* arg, absl::Span<const int64> dimensions,
1570 unsigned element_alignment) {
1571 ShardedVector accumulator;
1572 accumulator.reserve(accumulator_type.size());
1573 for (auto accumulator_shard_type : accumulator_type) {
1574 accumulator.push_back(llvm_ir::EmitAllocaAtFunctionEntry(
1575 accumulator_shard_type, "accumulator", &b_, 0));
1576 }
1577
1578 llvm::Value* init_value_ssa = Load(GetEmittedValueFor(init_value));
1579
1580 for (llvm::Value* accumulator_shard : accumulator) {
1581 llvm::Value* initial_value;
1582 auto shard_type = accumulator_shard->getType()->getPointerElementType();
1583 if (auto vector_type = llvm::dyn_cast<llvm::VectorType>(shard_type)) {
1584 initial_value =
1585 VectorSplat(vector_type->getNumElements(), init_value_ssa);
1586 } else {
1587 initial_value = init_value_ssa;
1588 }
1589
1590 AlignedStore(initial_value, accumulator_shard, element_alignment);
1591 }
1592
1593 llvm_ir::ForLoopNest reduction_loop_nest(IrName(arg, "vectorized_inner"),
1594 &b_);
1595 std::vector<llvm::Value*> input_multi_index =
1596 reduction_loop_nest.AddLoopsForShapeOnDimensions(arg->shape(), dimensions,
1597 "reduction_dim");
1598
1599 SetToFirstInsertPoint(reduction_loop_nest.GetInnerLoopBodyBasicBlock(), &b_);
1600
1601 llvm_ir::IrArray arg_array(GetIrArrayFor(arg));
1602 llvm_ir::IrArray::Index::const_iterator it = output_index.begin();
1603
1604 for (auto& i : input_multi_index) {
1605 if (i == nullptr) {
1606 i = *it++;
1607 }
1608 }
1609 CHECK(output_index.end() == it);
1610 llvm_ir::IrArray::Index input_index(input_multi_index, arg->shape(),
1611 b_.getInt64Ty());
1612
1613 llvm::Value* input_address = BitCast(
1614 arg_array.EmitArrayElementAddress(input_index, &b_), b_.getInt8PtrTy());
1615
1616 for (int i = 0; i < accumulator.size(); i++) {
1617 auto input_address_typed =
1618 BitCast(input_address, accumulator[i]->getType());
1619 auto current_accumulator_value =
1620 AlignedLoad(accumulator[i], element_alignment);
1621 auto addend = AlignedLoad(input_address_typed, element_alignment);
1622 arg_array.AnnotateLoadStoreInstructionWithMetadata(addend);
1623
1624 auto reduced_result =
1625 reduction_generator(&b_, current_accumulator_value, addend);
1626 AlignedStore(reduced_result, accumulator[i], element_alignment);
1627
1628 if (i != (accumulator.size() - 1)) {
1629 input_address = ConstInBoundsGEP1_32(reduced_result->getType(),
1630 input_address_typed, 1);
1631 }
1632 }
1633
1634 SetToFirstInsertPoint(reduction_loop_nest.GetOuterLoopExitBasicBlock(), &b_);
1635
1636 ShardedVector result_ssa;
1637 result_ssa.reserve(accumulator.size());
1638 for (auto accumulator_shard : accumulator) {
1639 result_ssa.push_back(AlignedLoad(accumulator_shard, element_alignment));
1640 }
1641 return result_ssa;
1642 }
1643
1644 void IrEmitter::EmitShardedVectorStore(
1645 llvm::Value* store_address, const std::vector<llvm::Value*>& value_to_store,
1646 const int alignment, const llvm_ir::IrArray& containing_array) {
1647 for (int i = 0; i < value_to_store.size(); i++) {
1648 auto store_address_typed =
1649 BitCast(store_address,
1650 llvm::PointerType::getUnqual(value_to_store[i]->getType()));
1651
1652 auto store_instruction =
1653 AlignedStore(value_to_store[i], store_address_typed, alignment);
1654 containing_array.AnnotateLoadStoreInstructionWithMetadata(
1655 store_instruction);
1656
1657 if (i != (value_to_store.size() - 1)) {
1658 store_address = ConstInBoundsGEP1_32(value_to_store[i]->getType(),
1659 store_address_typed, 1);
1660 }
1661 }
1662 }
1663
1664 StatusOr<bool> IrEmitter::EmitVectorizedReduce(
1665 HloInstruction* reduce, HloInstruction* arg, HloInstruction* init_value,
1666 absl::Span<const int64> dimensions, HloComputation* function,
1667 string* failure_reason) {
1668 if (!ReductionPreservesLayout(*reduce)) {
1669 return false;
1670 }
1671
1672 ReductionGenerator reduction_generator =
1673 MatchReductionGenerator(function, failure_reason);
1674 if (!reduction_generator) {
1675 return false;
1676 }
1677
1678 int vectorization_factor_in_bytes =
1679 target_machine_features_.vectorization_factor_in_bytes();
1680
1681 // We try to process vectorization_factor elements at the same time.
1682 const int vectorization_factor =
1683 vectorization_factor_in_bytes /
1684 ShapeUtil::ByteSizeOfPrimitiveType(reduce->shape().element_type());
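// For example (assumed target parameters): with a 32-byte vectorization
// factor and F32 elements, vectorization_factor is 8 elements per iteration.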
1685
1686 bool is_reduction_over_minor_dimension = absl::c_linear_search(
1687 dimensions, LayoutUtil::Minor(arg->shape().layout(), 0));
1688
1689 unsigned element_alignment = tensorflow::MathUtil::GCD<unsigned>(
1690 ShapeUtil::ByteSizeOfPrimitiveType(reduce->shape().element_type()),
1691 MinimumAlignmentForPrimitiveType(reduce->shape().element_type()));
1692
1693 if (is_reduction_over_minor_dimension) {
1694 // TODO(sanjoy): Implement vectorized reduction over the minor dimension.
1695 *failure_reason = "reduction over minor dimension not implemented";
1696 return false;
1697 }
1698
1699 CHECK(!reduce->shape().IsTuple());
1700 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(reduce));
1701
1702 // We know we're not reducing over the most minor dimension, which means we
1703 // can lower the reduction loop as:
1704 //
1705 // 1. We're reducing over dimensions R0, R1.
1706 // 2. D0 is the most minor dimension.
1707 // 3. VS is the vectorization stride (we want to reduce this many elements at
1708 // once)
1709 //
1710 // for (d1 in D1) {
1711 // for (d0 in D0 with stride VS) {
1712 // vector_acc = init
1713 // for (r1 in R1) {
1714 // for (r0 in R0) {
1715 // vector_acc = elementwise_reduce(vector_acc, input[d1, d0, r1, r0])
1716 // }
1717 // }
1718 // output[d1, d0] = vector_acc
1719 // }
1720 // }
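//
// As a concrete (hypothetical) instance of the loop structure above: with
// vectorization_factor = 8 and an innermost dimension of size 100, the
// vectorized loop below covers d0 in [0, 96) with stride 8, and the epilogue
// further down handles the remaining 4 elements with a narrower sharded
// vector type.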
1721
1722 llvm_ir::ForLoopNest loop_nest(IrName(reduce), &b_);
1723 std::vector<llvm::Value*> array_multi_index(
1724 reduce->shape().dimensions_size());
1725 for (int i = LayoutUtil::MinorToMajor(reduce->shape()).size() - 1; i > 0;
1726 --i) {
1727 int64 dimension = LayoutUtil::Minor(reduce->shape().layout(), i);
1728 int64 start_index = 0;
1729 int64 end_index = reduce->shape().dimensions(dimension);
1730 std::unique_ptr<llvm_ir::ForLoop> loop = loop_nest.AddLoop(
1731 start_index, end_index, absl::StrFormat("dim.%d", dimension));
1732 array_multi_index[dimension] = loop->GetIndVarValue();
1733 }
1734
1735 int64 innermost_dimension = LayoutUtil::Minor(reduce->shape().layout(), 0);
1736 int64 innermost_dimension_size =
1737 reduce->shape().dimensions(innermost_dimension);
1738
1739 if (llvm::BasicBlock* innermost_body_bb =
1740 loop_nest.GetInnerLoopBodyBasicBlock()) {
1741 SetToFirstInsertPoint(innermost_body_bb, &b_);
1742 }
1743
1744 auto outermost_loop_exit_block = loop_nest.GetOuterLoopExitBasicBlock();
1745
1746 if (innermost_dimension_size >= vectorization_factor) {
1747 int64 start_index = 0;
1748 int64 end_index = (innermost_dimension_size / vectorization_factor) *
1749 vectorization_factor;
1750 std::unique_ptr<llvm_ir::ForLoop> loop =
1751 loop_nest.AddLoop(start_index, end_index, vectorization_factor,
1752 absl::StrFormat("dim.%d", innermost_dimension));
1753 array_multi_index[innermost_dimension] = loop->GetIndVarValue();
1754
1755 SetToFirstInsertPoint(loop->GetBodyBasicBlock(), &b_);
1756
1757 ShardedVectorType vector_type = CreateShardedVectorType(
1758 reduce->shape().element_type(), vectorization_factor);
1759 llvm_ir::IrArray::Index array_index(array_multi_index, reduce->shape(),
1760 b_.getInt64Ty());
1761 TF_ASSIGN_OR_RETURN(std::vector<llvm::Value*> accumulator,
1762 EmitInnerLoopForVectorizedReduction(
1763 reduction_generator, array_index, vector_type,
1764 init_value, arg, dimensions, element_alignment));
1765
1766 llvm_ir::IrArray target_array = GetIrArrayFor(reduce);
1767 llvm::Value* output_address =
1768 target_array.EmitArrayElementAddress(array_index, &b_);
1769 EmitShardedVectorStore(output_address, accumulator, element_alignment,
1770 target_array);
1771
1772 if (auto exit_terminator = loop->GetExitBasicBlock()->getTerminator()) {
1773 CHECK_GT(LayoutUtil::MinorToMajor(reduce->shape()).size(), 1);
1774 b_.SetInsertPoint(exit_terminator);
1775 } else {
1776 CHECK_EQ(LayoutUtil::MinorToMajor(reduce->shape()).size(), 1);
1777 b_.SetInsertPoint(loop->GetExitBasicBlock());
1778 }
1779 }
1780
1781 // Since the inner-dimension loop advances by the vectorization factor rather
1782 // than by 1, we may need to peel off an "epilogue" iteration to handle the
1783 // remaining elements when the dimension size is not a multiple of that factor:
1784 if (innermost_dimension_size % vectorization_factor) {
1785 // TODO(b/63775531): Consider using a scalar loop here to save on code size.
1786 array_multi_index[innermost_dimension] =
1787 b_.getInt64(innermost_dimension_size -
1788 (innermost_dimension_size % vectorization_factor));
1789
1790 ShardedVectorType vector_type = CreateShardedVectorType(
1791 reduce->shape().element_type(),
1792 innermost_dimension_size % vectorization_factor);
1793 llvm_ir::IrArray::Index array_index(array_multi_index, reduce->shape(),
1794 b_.getInt64Ty());
1795 TF_ASSIGN_OR_RETURN(std::vector<llvm::Value*> accumulator,
1796 EmitInnerLoopForVectorizedReduction(
1797 reduction_generator, array_index, vector_type,
1798 init_value, arg, dimensions, element_alignment));
1799
1800 llvm_ir::IrArray target_array = GetIrArrayFor(reduce);
1801 llvm::Value* output_address =
1802 target_array.EmitArrayElementAddress(array_index, &b_);
1803 EmitShardedVectorStore(output_address, accumulator, element_alignment,
1804 target_array);
1805 }
1806
1807 if (outermost_loop_exit_block) {
1808 b_.SetInsertPoint(outermost_loop_exit_block);
1809 }
1810
1811 return true;
1812 }
1813
1814 StatusOr<llvm::Value*> IrEmitter::EmitElementalReduce(
1815 const HloReduceInstruction* reduce,
1816 const llvm_ir::ElementGenerator& input_generator,
1817 const llvm_ir::ElementGenerator& initial_value_generator,
1818 const llvm_ir::IrArray::Index& index) {
1819 const HloInstruction* arg = reduce->operand(0);
1820 absl::Span<const int64> dimensions(reduce->dimensions());
1821
1822 // Initialize an accumulator with init_value.
1823 PrimitiveType accumulator_type = reduce->shape().element_type();
1824 llvm::AllocaInst* accumulator_addr = llvm_ir::EmitAllocaAtFunctionEntry(
1825 llvm_ir::PrimitiveTypeToIrType(accumulator_type, module_), "accumulator",
1826 &b_, MinimumAlignmentForPrimitiveType(accumulator_type));
1827 TF_ASSIGN_OR_RETURN(
1828 llvm::Value* const init_value,
1829 initial_value_generator(llvm_ir::IrArray::Index(index.GetType())));
1830 Store(init_value, accumulator_addr);
1831
1832 // The enclosing loops go over all the target elements. Now we have to compute
1833 // the actual target element. For this, we build a new loop nest to iterate
1834 // over all the reduction dimensions in the argument.
1835 // AddLoopsForShapeOnDimensions will return an Index where induction Value*s
1836 // are placed for each dimension in dimensions, and all the rest are nullptrs.
1837 llvm_ir::ForLoopNest loops(IrName(reduce, "inner"), &b_);
1838 std::vector<llvm::Value*> input_multi_index =
1839 loops.AddLoopsForShapeOnDimensions(arg->shape(), dimensions,
1840 "reduction_dim");
1841
1842 SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_);
1843
1844 // Build a full index for the input argument, using input_multi_index as the
1845 // base; so far only the reduction dimensions are filled in. We
1846 // fill in the rest of the dimensions with induction Value*s taken from
1847 // 'index' which iterates over the target array. See the high-level
1848 // description in the XLA documentation for details.
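//
// For illustration (hypothetical shapes): when reducing f32[5,6,7] over
// dimension {1}, 'index' supplies the output coordinates (i, k), the loop
// nest above supplies the induction variable j for dimension 1, and the
// resulting input_index is (i, j, k).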
1849 llvm_ir::IrArray::Index::const_iterator it = index.begin();
1850
1851 for (auto& i : input_multi_index) {
1852 if (i == nullptr) {
1853 i = *it++;
1854 }
1855 }
1856 CHECK(index.end() == it);
1857 llvm_ir::IrArray::Index input_index(input_multi_index, arg->shape(),
1858 b_.getInt64Ty());
1859
1860 // Apply the reduction function to the loaded value.
1861 TF_ASSIGN_OR_RETURN(llvm::Value* const input_element,
1862 input_generator(input_index));
1863 llvm::Value* result = EmitThreadLocalCall(
1864 *reduce->to_apply(), {Load(accumulator_addr), input_element},
1865 "reduce_function");
1866 Store(result, accumulator_addr);
1867
1868 SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_);
1869 return Load(accumulator_addr);
1870 }
1871
1872 Status IrEmitter::HandleReduce(HloInstruction* reduce) {
1873 // TODO(b/118333695): Support variadic reduce.
1874 if (!reduce->shape().IsArray()) {
1875 return Unimplemented("Variadic reduce is not supported on CPU");
1876 }
1877 auto arg = reduce->mutable_operand(0);
1878 auto init_value = reduce->mutable_operand(1);
1879 absl::Span<const int64> dimensions(reduce->dimensions());
1880 HloComputation* function = reduce->to_apply();
1881 if (!options::VectorizedReduceDisabled(hlo_module_config_)) {
1882 string vectorization_failure_reason;
1883 TF_ASSIGN_OR_RETURN(
1884 bool vectorization_successful,
1885 EmitVectorizedReduce(reduce, arg, init_value, dimensions, function,
1886 &vectorization_failure_reason));
1887 if (vectorization_successful) {
1888 VLOG(1) << "Successfully vectorized reduction " << reduce->ToString()
1889 << "\n";
1890 return Status::OK();
1891 } else {
1892 VLOG(1) << "Could not vectorize reduction " << reduce->ToString() << ": "
1893 << vectorization_failure_reason;
1894 }
1895 }
1896
1897 return DefaultAction(reduce);
1898 }
1899
1900 Status IrEmitter::HandleAllToAll(HloInstruction*) {
1901 return Unimplemented("AllToAll is not implemented on CPU.");
1902 }
1903
1904 Status IrEmitter::HandleSend(HloInstruction* send) {
1905 // TODO(b/33942983): Support Send/Recv on CPU.
1906 return Unimplemented("Send is not implemented on CPU.");
1907 }
1908
1909 Status IrEmitter::HandleSendDone(HloInstruction* send_done) {
1910 // TODO(b/33942983): Support Send/Recv on CPU.
1911 return Unimplemented("Send-done is not implemented on CPU.");
1912 }
1913
1914 Status IrEmitter::HandleScatter(HloInstruction*) {
1915 return Unimplemented("Scatter is not implemented on CPUs.");
1916 }
1917
1918 Status IrEmitter::HandleSlice(HloInstruction* slice) {
1919 VLOG(2) << "HandleSlice: " << slice->ToString();
1920 auto operand = slice->operand(0);
1921 // The code below emits a sequential loop nest. For the parallel backend, use
1922 // ParallelLoopEmitter which respects dynamic loop bounds.
1923 if (ShouldEmitParallelLoopFor(*slice)) {
1924 return DefaultAction(slice);
1925 }
1926
1927 // The code below assumes the layouts are equal.
1928 if (!LayoutUtil::Equal(operand->shape().layout(), slice->shape().layout())) {
1929 return DefaultAction(slice);
1930 }
1931
1932 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(slice));
1933
1934 if (ShapeUtil::IsZeroElementArray(slice->shape())) {
1935 return Status::OK();
1936 }
1937
1938 const Layout& layout = operand->shape().layout();
1939 const int64 num_dims = operand->shape().dimensions_size();
1940
1941 // The slice lowering finds maximal contiguous blocks of memory that can be
1942 // copied from the source to the target. This is done by looking at the
1943 // source/target layout in minor-to-major order and doing the following:
1944 //
1945 // * Find an initial segment of dimensions along which the slice uses the
1946 // whole dimension. These are the "inner" dimensions and can be folded into
1947 // the memcpy.
1948 //
1949 // * Of the remaining dimensions decide which ones require loops.
1950 //
1951 // * Implement the memcpy within the innermost loop.
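//
// For illustration (hypothetical shapes): slicing f32[4,8,16]{2,1,0} so that
// dimension 2 is taken whole, dimension 1 is sliced with stride 1, and
// dimension 0 is sliced, gives inner_dims = {2} and memcpy_dim = 1; each
// memcpy then copies (limit - start) * 16 contiguous elements, with a loop
// over dimension 0. A non-unit stride on dimension 1 would instead turn
// dimension 1 into a loop and shrink each memcpy to 16 elements.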
1952
1953 absl::flat_hash_set<int64> inner_dims;
1954 for (int64 dim : LayoutUtil::MinorToMajor(layout)) {
1955 if (operand->shape().dimensions(dim) != slice->shape().dimensions(dim)) {
1956 break;
1957 }
1958 inner_dims.insert(dim);
1959 }
1960
1961 const bool is_trivial_copy = (inner_dims.size() == num_dims);
1962 if (is_trivial_copy) {
1963 if (ShapeUtil::IsEffectiveScalar(slice->shape())) {
1964 return DefaultAction(slice);
1965 } else {
1966 return EmitMemcpy(*slice, *operand);
1967 }
1968 }
1969
1970 // The memcpy will copy elements that are logically this shape (allowed to be
1971 // scalar).
1972 const Shape logical_element_shape = ShapeUtil::FilterDimensions(
1973 [&inner_dims](int64 dim) { return inner_dims.contains(dim); },
1974 operand->shape());
1975
1976 const int64 primitive_elements_per_logical_element =
1977 ShapeUtil::ElementsIn(logical_element_shape);
1978
1979 // memcpy_dim is the innermost (in terms of layout) dimension for which the
1980 // slice does *not* just copy all the elements along the dimension.
1981 const int64 memcpy_dim = LayoutUtil::Minor(layout, inner_dims.size());
1982
1983 const bool memcpy_is_contiguous = slice->slice_strides(memcpy_dim) == 1;
1984 // The number of logical elements that can be copied in a single call
1985 // to memcpy. We can only copy 1 element at a time if there is a non-trivial
1986 // stride.
1987 const int64 memcpy_logical_elements =
1988 memcpy_is_contiguous
1989 ? slice->slice_limits(memcpy_dim) - slice->slice_starts(memcpy_dim)
1990 : 1;
1991
1992 // Determine the dimensions that get lowered as loops.
1993 std::vector<int64> outer_dims;
1994 for (int64 i = 0; i < num_dims - inner_dims.size() - 1; ++i) {
1995 outer_dims.push_back(LayoutUtil::Major(layout, i));
1996 }
1997
1998 // Is the slice along the memcpy dimension contiguous? If not, then memcpy_dim
1999 // needs to be wrapped around a loop as well.
2000 if (!memcpy_is_contiguous) {
2001 outer_dims.push_back(memcpy_dim);
2002 }
2003
2004 llvm_ir::IrArray target_array = GetIrArrayFor(slice);
2005
2006 const int64 num_outer_loops = outer_dims.size();
2007 llvm_ir::ForLoopNest loops(IrName(slice), &b_);
2008 std::vector<llvm::Value*> target_multi_index =
2009 loops.AddLoopsForShapeOnDimensions(slice->shape(), outer_dims, "slice");
2010
2011 // Only the indices for the outer dimensions have been initialized in
2012 // target_multi_index. The rest of the indices should get initialized to 0,
2013 // since for the rest of the dimensions the copy writes to the full dimension.
2014 std::replace(target_multi_index.begin(), target_multi_index.end(),
2015 static_cast<llvm::Value*>(nullptr),
2016 static_cast<llvm::Value*>(b_.getInt64(0)));
2017 llvm_ir::IrArray::Index target_index(target_multi_index, slice->shape(),
2018 b_.getInt64Ty());
2019
2020 if (num_outer_loops > 0) {
2021 SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_);
2022 }
2023
2024 llvm_ir::IrArray source_array = GetIrArrayFor(operand);
2025 const llvm_ir::IrArray::Index source_index = target_index.SourceIndexOfSlice(
2026 /*operand_shape=*/operand->shape(), /*starts=*/slice->slice_starts(),
2027 /*strides=*/slice->slice_strides(), /*builder=*/&b_);
2028
2029 llvm::Value* memcpy_dest =
2030 target_array.EmitArrayElementAddress(target_index, &b_, "slice.dest");
2031 llvm::Value* memcpy_source =
2032 source_array.EmitArrayElementAddress(source_index, &b_, "slice.source");
2033
2034 const int64 memcpy_elements =
2035 primitive_elements_per_logical_element * memcpy_logical_elements;
2036
2037 EmitTransferElements(memcpy_dest, memcpy_source, memcpy_elements,
2038 slice->shape().element_type(), target_array,
2039 source_array);
2040
2041 if (VLOG_IS_ON(2)) {
2042 const int64 memcpy_bytes =
2043 ShapeUtil::ByteSizeOf(logical_element_shape) * memcpy_elements;
2044 VLOG(2) << " emitted copy of " << memcpy_bytes << " bytes inside "
2045 << num_outer_loops << " loops";
2046 }
2047
2048 if (num_outer_loops > 0) {
2049 SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_);
2050 }
2051
2052 return Status::OK();
2053 }
2054
2055 Status IrEmitter::HandleDynamicSlice(HloInstruction* dynamic_slice) {
2056 if (ShapeUtil::IsScalar(dynamic_slice->shape())) {
2057 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(dynamic_slice));
2058 return EmitMemcpy(*dynamic_slice->operand(0), *dynamic_slice);
2059 }
2060 return DefaultAction(dynamic_slice);
2061 }
2062
2063 Status IrEmitter::HandleDynamicUpdateSlice(
2064 HloInstruction* dynamic_update_slice) {
2065 auto update = dynamic_update_slice->operand(1);
2066 if (ShapeUtil::IsScalar(dynamic_update_slice->shape())) {
2067 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(dynamic_update_slice));
2068 return EmitMemcpy(*update, *dynamic_update_slice);
2069 } else if (llvm_ir::CanUpdateDynamicSliceInPlace(dynamic_update_slice,
2070 assignment_)) {
2071 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(dynamic_update_slice));
2072 auto operands = GetIrArraysForOperandsOf(dynamic_update_slice);
2073 return llvm_ir::EmitDynamicUpdateSliceInPlace(
2074 operands, GetIrArrayFor(dynamic_update_slice),
2075 IrName(dynamic_update_slice, "in_place"), &b_);
2076 }
2077 return DefaultAction(dynamic_update_slice);
2078 }
2079
2080 Status IrEmitter::HandleRecv(HloInstruction* recv) {
2081 // TODO(b/33942983): Support Send/Recv on CPU.
2082 return Unimplemented("Recv is not implemented on CPU.");
2083 }
2084
2085 Status IrEmitter::HandleRecvDone(HloInstruction* recv_done) {
2086 // TODO(b/33942983): Support Send/Recv on CPU.
2087 return Unimplemented("Recv-done is not implemented on CPU.");
2088 }
2089
2090 Status IrEmitter::HandlePad(HloInstruction* pad) {
2091 // The CPU backend does not properly handle negative padding, but this is OK
2092 // because negative padding should have been removed by the algebraic simplifier.
2093 for (auto& padding_dimension : pad->padding_config().dimensions()) {
2094 if (padding_dimension.edge_padding_low() < 0 ||
2095 padding_dimension.edge_padding_high() < 0) {
2096 return InternalErrorStrCat(
2097 "Encountered negative padding in IrEmitter on CPU. "
2098 "This should have been eliminated at the HLO level. ",
2099 pad->ToString());
2100 }
2101 }
2102
2103 // First, fill in the padding value to all output elements.
2104 TF_RETURN_IF_ERROR(EmitTargetElementLoop(
2105 pad, "initialize",
2106 [this, pad](const llvm_ir::IrArray::Index& target_index) {
2107 const HloInstruction* padding_value = pad->operand(1);
2108 llvm::Value* padding_value_addr = GetEmittedValueFor(padding_value);
2109 return Load(padding_value_addr);
2110 }));
2111
2112 // Create a loop to iterate over the operand elements and update the output
2113 // locations where the operand elements should be stored.
2114 llvm_ir::ForLoopNest loops(IrName(pad, "assign"), &b_);
2115 const HloInstruction* operand = pad->operand(0);
2116 const llvm_ir::IrArray::Index operand_index =
2117 loops.AddLoopsForShape(operand->shape(), "operand");
2118
2119 SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_);
2120
2121 // Load an element from the operand.
2122 llvm_ir::IrArray operand_array(GetIrArrayFor(operand));
2123 llvm::Value* operand_data =
2124 operand_array.EmitReadArrayElement(operand_index, &b_);
2125
2126 // Compute the output index the operand element should be assigned to.
2127 // output_index := edge_padding_low + operand_index * (interior_padding + 1)
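// For example (assumed padding config): with edge_padding_low = 1 and
// interior_padding = 2 on a dimension, the operand element at index 2 is
// written to output index 1 + 2 * (2 + 1) = 7.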
2128 const PaddingConfig& padding_config = pad->padding_config();
2129 std::vector<llvm::Value*> output_multi_index;
2130 for (size_t i = 0; i < operand_index.size(); ++i) {
2131 llvm::Value* offset =
2132 Mul(operand_index[i],
2133 b_.getInt64(padding_config.dimensions(i).interior_padding() + 1));
2134 llvm::Value* index = Add(
2135 offset, b_.getInt64(padding_config.dimensions(i).edge_padding_low()));
2136 output_multi_index.push_back(index);
2137 }
2138
2139 // Store the operand element to the computed output location.
2140 llvm_ir::IrArray output_array(GetIrArrayFor(pad));
2141 llvm_ir::IrArray::Index output_index(
2142 output_multi_index, output_array.GetShape(), operand_index.GetType());
2143 output_array.EmitWriteArrayElement(output_index, operand_data, &b_);
2144
2145 SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_);
2146 return Status::OK();
2147 }
2148
2149 Status IrEmitter::HandleFusion(HloInstruction* fusion) {
2150 auto* root = fusion->fused_expression_root();
2151 if (llvm_ir::CanEmitFusedDynamicUpdateSliceInPlace(fusion, assignment_)) {
2152 VLOG(3) << "HandleFusion FusedDynamicUpdateSliceInPlace";
2153 CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_);
2154 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(fusion));
2155 // Delegate to common implementation of fused in-place dynamic-update-slice.
2156 return llvm_ir::EmitFusedDynamicUpdateSliceInPlace(
2157 fusion, GetGeneratorForOperandIrArrays(fusion), GetIrArrayFor(fusion),
2158 &elemental_emitter, &b_);
2159 } else if (fusion->fusion_kind() == HloInstruction::FusionKind::kLoop) {
2160 VLOG(3) << "HandleFusion kLoop";
2161 CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_);
2162 auto operands = GetIrArraysForOperandsOf(fusion);
2163 FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(fusion),
2164 &elemental_emitter);
2165 TF_RETURN_IF_ERROR(fusion->fused_expression_root()->Accept(&fused_emitter));
2166
2167 return EmitTargetElementLoop(fusion, fused_emitter.GetRootGenerator());
2168 } else if (fusion->fusion_kind() == HloInstruction::FusionKind::kOutput) {
2169 VLOG(3) << "HandleFusion kOutput";
2170 int64 dot_op_index = root->operand(0)->opcode() == HloOpcode::kDot ? 0 : 1;
2171 const HloInstruction* dot = root->operand(dot_op_index);
2172 CHECK_EQ(dot->opcode(), HloOpcode::kDot)
2173 << dot->ToString() << " "
2174 << fusion->fused_instructions_computation()->ToString();
2175
2176 int64 dot_lhs_param_number = dot->operand(0)->parameter_number();
2177 int64 dot_rhs_param_number = dot->operand(1)->parameter_number();
2178 int64 addend_param_number =
2179 root->operand(1 - dot_op_index)->parameter_number();
2180
2181 Shape target_shape = fusion->shape();
2182 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(fusion));
2183 llvm_ir::IrArray target_array = GetIrArrayFor(fusion);
2184
2185 llvm_ir::IrArray lhs_array(
2186 GetIrArrayFor(fusion->operand(dot_lhs_param_number)));
2187 llvm_ir::IrArray rhs_array(
2188 GetIrArrayFor(fusion->operand(dot_rhs_param_number)));
2189 llvm_ir::IrArray addend_array(
2190 GetIrArrayFor(fusion->operand(addend_param_number)));
2191
2192 TF_RETURN_IF_ERROR(
2193 EmitDotOperation(*dot, target_array, lhs_array, rhs_array,
2194 &addend_array, GetExecutableRunOptionsArgument(), &b_,
2195 hlo_module_config_, target_machine_features_));
2196 return Status::OK();
2197 } else {
2198 return Unimplemented("Fusion kind not implemented on CPU");
2199 }
2200 }
2201
2202 Status IrEmitter::HandleCall(HloInstruction* call) {
2203 HloComputation* computation = call->to_apply();
2204 llvm::Function* call_ir_function = FindOrDie(emitted_functions_, computation);
2205
2206 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(call));
2207
2208 if (!computation->root_instruction()->outer_dimension_partitions().empty()) {
2209 // ParallelTaskAssignment assigned partitions, emit call to
2210 // ParallelForkJoin.
2211 std::vector<llvm::Value*> call_args = GetArrayFunctionCallArguments(
2212 {}, &b_, computation->name(),
2213 /*return_value_buffer=*/emitted_value_[call],
2214 /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(),
2215 /*buffer_table_arg=*/GetBufferTableArgument(),
2216 /*profile_counters_arg=*/GetProfileCountersArgument());
2217
2218 HloInstruction* root = computation->root_instruction();
2219 TF_RETURN_IF_ERROR(EmitCallToParallelForkJoin(
2220 call_args, root->shape(), root->outer_dimension_partitions(), &b_,
2221 call_ir_function, computation->name()));
2222 } else {
2223 EmitGlobalCall(*computation, computation->name());
2224 }
2225
2226 return Status::OK();
2227 }
2228
2229 Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) {
2230 absl::Span<HloInstruction* const> operands(custom_call->operands());
2231 llvm::Type* i8_ptr_type = b_.getInt8PtrTy();
2232 llvm::AllocaInst* operands_alloca =
2233 llvm_ir::EmitAllocaAtFunctionEntryWithCount(
2234 i8_ptr_type, b_.getInt32(operands.size()), "cc_operands_alloca", &b_);
2235 for (size_t i = 0; i < operands.size(); ++i) {
2236 const HloInstruction* operand = operands[i];
2237 llvm::Value* operand_as_i8ptr =
2238 PointerCast(GetEmittedValueFor(operand), i8_ptr_type);
2239 llvm::Value* slot_in_operands_alloca =
2240 InBoundsGEP(operands_alloca, {b_.getInt64(i)});
2241 Store(operand_as_i8ptr, slot_in_operands_alloca);
2242 }
2243 if (emit_code_for_msan_) {
2244 // Mark the alloca as initialized for msan. The buffer gets read by the
2245 // custom callee, which might be msan-instrumented.
2246 // TODO(b/66051036): Run the msan instrumentation pass instead.
2247 const llvm::DataLayout& dl = module_->getDataLayout();
2248 llvm::Type* intptr_type = b_.getIntPtrTy(dl);
2249 auto* msan_unpoison_ir_function = llvm::cast<llvm::Function>(
2250 module_
2251 ->getOrInsertFunction(
2252 "__msan_unpoison",
2253 llvm::FunctionType::get(
2254 /*Result=*/b_.getVoidTy(),
2255 /*Params=*/{i8_ptr_type, intptr_type}, /*isVarArg=*/false))
2256 .getCallee());
2257 Call(msan_unpoison_ir_function,
2258 {PointerCast(operands_alloca, i8_ptr_type),
2259 llvm::ConstantInt::get(
2260 intptr_type, *operands_alloca->getAllocationSizeInBits(dl) / 8)});
2261 }
2262 auto* custom_call_ir_function = llvm::dyn_cast<llvm::Function>(
2263 module_
2264 ->getOrInsertFunction(
2265 custom_call->custom_call_target(),
2266 llvm::FunctionType::get(
2267 /*Result=*/b_.getVoidTy(),
2268 /*Params=*/{i8_ptr_type, operands_alloca->getType()},
2269 /*isVarArg=*/false))
2270 .getCallee());
2271
2272 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(custom_call));
2273 // Write the tuple table if the output is a tuple.
2274 if (custom_call->shape().IsTuple()) {
2275 std::vector<llvm::Value*> base_ptrs;
2276 for (int i = 0; i < ShapeUtil::TupleElementCount(custom_call->shape());
2277 ++i) {
2278 const Shape& elem_shape =
2279 ShapeUtil::GetTupleElementShape(custom_call->shape(), i);
2280 TF_RET_CHECK(!elem_shape.IsTuple()) << "Nested tuples not implemented";
2281 TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice,
2282 assignment_.GetUniqueSlice(custom_call, {i}));
2283 llvm::Value* addr = EmitBufferPointer(slice, elem_shape);
2284 base_ptrs.push_back(addr);
2285 }
2286 llvm_ir::EmitTuple(GetIrArrayFor(custom_call), base_ptrs, &b_);
2287 }
2288 auto* output_address_arg =
2289 PointerCast(GetEmittedValueFor(custom_call), i8_ptr_type);
2290
2291 Call(custom_call_ir_function, {output_address_arg, operands_alloca});
2292
2293 return Status::OK();
2294 }
2295
2296 Status IrEmitter::HandleWhile(HloInstruction* xla_while) {
2297 // Precondition: Condition computation must return a scalar bool.
2298 HloComputation* condition = xla_while->while_condition();
2299 TF_RET_CHECK(ShapeUtil::IsScalar(condition->root_instruction()->shape()) &&
2300 condition->root_instruction()->shape().element_type() == PRED)
2301 << "While condition computation must return bool; got: "
2302 << ShapeUtil::HumanString(condition->root_instruction()->shape());
2303 // Check that all while-related buffers share an allocation slice.
2304 TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
2305 xla_while->shape(),
2306 [this, &xla_while](const Shape& /*subshape*/,
2307 const ShapeIndex& index) -> Status {
2308 auto check = [this](const HloInstruction* a, const HloInstruction* b,
2309 const ShapeIndex& index) {
2310 const BufferAllocation::Slice slice_a =
2311 assignment_.GetUniqueSlice(a, index).ConsumeValueOrDie();
2312 const BufferAllocation::Slice slice_b =
2313 assignment_.GetUniqueSlice(b, index).ConsumeValueOrDie();
2314 if (slice_a != slice_b) {
2315 return InternalError(
2316 "instruction %s %s does not share slice with "
2317 "instruction %s %s",
2318 a->ToString(), slice_a.ToString(), b->ToString(),
2319 slice_b.ToString());
2320 }
2321 return Status::OK();
2322 };
2323 TF_RETURN_IF_ERROR(check(xla_while, xla_while->operand(0), index));
2324 TF_RETURN_IF_ERROR(check(
2325 xla_while, xla_while->while_condition()->parameter_instruction(0),
2326 index));
2327 TF_RETURN_IF_ERROR(
2328 check(xla_while, xla_while->while_body()->parameter_instruction(0),
2329 index));
2330 TF_RETURN_IF_ERROR(check(
2331 xla_while, xla_while->while_body()->root_instruction(), index));
2332 return Status::OK();
2333 }));
2334
2335 // Set emitted value to that of 'init' with which it shares an allocation.
2336 const HloInstruction* init = xla_while->operand(0);
2337 emitted_value_[xla_while] = GetEmittedValueFor(init);
2338
2339 // Generating:
2340 // while (Condition(while_result)) {
2341 // // CopyInsertion pass inserts copies which enable 'while_result' to
2342 // // be passed back in as 'Body' parameter.
2343 // while_result = Body(while_result); // Insert
2344 // }
2345
2346 // Terminates the current block with a branch to a while header.
2347 llvm::BasicBlock* header_bb = llvm::BasicBlock::Create(
2348 module_->getContext(), IrName(xla_while, "header"),
2349 compute_function_->function());
2350 Br(header_bb);
2351 b_.SetInsertPoint(header_bb);
2352
2353 // Calls the condition function to determine whether to proceed with the
2354 // body. It must return a bool, so use the scalar call form.
2355 EmitGlobalCall(*xla_while->while_condition(), IrName(xla_while, "cond"));
2356 llvm::Value* while_predicate = ICmpNE(
2357 Load(GetBufferForGlobalCallReturnValue(*xla_while->while_condition())),
2358 llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0));
2359
2360 // Branches to the body or to the while exit depending on the condition.
2361 llvm::BasicBlock* body_bb =
2362 llvm::BasicBlock::Create(module_->getContext(), IrName(xla_while, "body"),
2363 compute_function_->function());
2364 llvm::BasicBlock* exit_bb = llvm::BasicBlock::Create(
2365 module_->getContext(), IrName(xla_while, "exit"));
2366 CondBr(while_predicate, body_bb, exit_bb);
2367
2368 // Calls the body function from the body block.
2369 b_.SetInsertPoint(body_bb);
2370
2371 // Calls the body function.
2372 EmitGlobalCall(*xla_while->while_body(), IrName(xla_while, "body"));
2373
2374 // Finishes with a branch back to the header.
2375 Br(header_bb);
2376
2377 // Adds the exit block to the function and sets the insert point there.
2378 compute_function_->function()->getBasicBlockList().push_back(exit_bb);
2379 b_.SetInsertPoint(exit_bb);
2380
2381 return Status::OK();
2382 }
2383
2384 StatusOr<bool> IrEmitter::EmitFastConcatenate(
2385 HloInstruction* concatenate, absl::Span<HloInstruction* const> operands,
2386 string* failure_reason) {
2387 if (ShouldEmitParallelLoopFor(*concatenate)) {
2388 *failure_reason =
2389 "cannot generate memcpy-based concat for the parallel CPU backend";
2390 return false;
2391 }
2392
2393 const Shape& output_shape = concatenate->shape();
2394 for (auto* op : operands) {
2395 if (!LayoutUtil::Equal(op->shape().layout(), output_shape.layout())) {
2396 *failure_reason = "operand has mismatching layouts";
2397 return false;
2398 }
2399 }
2400
2401 // We split the dimensions into three categories: the dimension over which we
2402 // are concatenating (concat_dim), the dimensions that are minor to it
2403 // (inner_dims) and the dimensions that are major to it (outer_dims).
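//
// For illustration (hypothetical shapes): concatenating f32[2,3,5]{2,1,0} and
// f32[2,4,5]{2,1,0} along dimension 1 yields f32[2,7,5]; then inner_dims =
// {2}, outer_dims = {0}, and for each outer index the first operand
// contributes a contiguous run of 3 * 5 elements followed by 4 * 5 elements
// from the second operand.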
2404
2405 int64 concat_dim = concatenate->dimensions(0);
2406 const Layout& output_layout = output_shape.layout();
2407 auto output_min2maj = LayoutUtil::MinorToMajor(output_layout);
2408 auto concat_dim_layout_itr = absl::c_find(output_min2maj, concat_dim);
2409
2410 std::vector<int64> inner_dims(output_min2maj.begin(), concat_dim_layout_itr);
2411 std::vector<int64> outer_dims(std::next(concat_dim_layout_itr),
2412 output_min2maj.end());
2413
2414 llvm::Type* i8_ptr_type = b_.getInt8PtrTy();
2415
2416 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(concatenate));
2417 llvm_ir::IrArray target_array = GetIrArrayFor(concatenate);
2418
2419 llvm_ir::ForLoopNest loops(IrName(concatenate), &b_);
2420 std::vector<llvm::Value*> target_multi_index =
2421 loops.AddLoopsForShapeOnDimensions(output_shape, outer_dims, "concat");
2422 std::replace(target_multi_index.begin(), target_multi_index.end(),
2423 static_cast<llvm::Value*>(nullptr),
2424 static_cast<llvm::Value*>(b_.getInt64(0)));
2425 llvm_ir::IrArray::Index target_index(target_multi_index, output_shape,
2426 b_.getInt64Ty());
2427
2428 if (!outer_dims.empty()) {
2429 SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_);
2430 }
2431
2432 PrimitiveType primitive_type = output_shape.element_type();
2433 unsigned primitive_type_size =
2434 ShapeUtil::ByteSizeOfPrimitiveType(primitive_type);
2435
2436 // Contiguous subregions from each operand to the concatenate contribute to a
2437 // contiguous subregion in the target buffer starting at target_region_begin.
2438 llvm::Value* target_region_begin = BitCast(
2439 target_array.EmitArrayElementAddress(target_index, &b_, "target_region"),
2440 i8_ptr_type);
2441 int64 byte_offset_into_target_region = 0;
2442
2443 int64 inner_dims_product =
2444 std::accumulate(inner_dims.begin(), inner_dims.end(), 1l,
2445 [&](int64 product, int64 inner_dim) {
2446 return product * output_shape.dimensions(inner_dim);
2447 });
2448
2449 // For each operand, emit a memcpy from the operand to the target of size
2450 // equal to the product of inner dimensions.
2451 for (HloInstruction* operand : operands) {
2452 const Shape& input_shape = operand->shape();
2453 llvm_ir::IrArray source_array = GetIrArrayFor(operand);
2454 llvm::Value* copy_source_address = BitCast(
2455 source_array.EmitArrayElementAddress(target_index, &b_, "src_addr"),
2456 i8_ptr_type);
2457
2458 llvm::Value* copy_target_address =
2459 GEP(target_region_begin, b_.getInt64(byte_offset_into_target_region));
2460
2461 EmitTransferElements(
2462 copy_target_address, copy_source_address,
2463 inner_dims_product * input_shape.dimensions(concat_dim), primitive_type,
2464 target_array, source_array);
2465
2466 byte_offset_into_target_region += inner_dims_product *
2467 input_shape.dimensions(concat_dim) *
2468 primitive_type_size;
2469 }
2470
2471 if (!outer_dims.empty()) {
2472 SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_);
2473 }
2474
2475 return true;
2476 }
2477
2478 void IrEmitter::EmitTransferElements(llvm::Value* target, llvm::Value* source,
2479 int64 element_count,
2480 PrimitiveType primitive_type,
2481 const llvm_ir::IrArray& target_array,
2482 const llvm_ir::IrArray& source_array) {
2483 unsigned primitive_type_size =
2484 ShapeUtil::ByteSizeOfPrimitiveType(primitive_type);
2485 unsigned element_alignment = tensorflow::MathUtil::GCD<unsigned>(
2486 primitive_type_size, MinimumAlignmentForPrimitiveType(primitive_type));
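// For example: for F32 (4 bytes) with an assumed 16-byte minimum alignment,
// element_alignment = gcd(4, 16) = 4.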
2487 llvm::Type* primitive_ptr_type = llvm::PointerType::getUnqual(
2488 llvm_ir::PrimitiveTypeToIrType(primitive_type, module_));
2489
2490 if (element_count == 1) {
2491 auto* load_instruction =
2492 AlignedLoad(BitCast(source, primitive_ptr_type), element_alignment);
2493 source_array.AnnotateLoadStoreInstructionWithMetadata(load_instruction);
2494 auto* store_instruction =
2495 AlignedStore(load_instruction, BitCast(target, primitive_ptr_type),
2496 element_alignment);
2497 target_array.AnnotateLoadStoreInstructionWithMetadata(store_instruction);
2498 } else {
2499 auto* memcpy_instruction = MemCpy(
2500 target, /*DstAlign=*/element_alignment, source,
2501 /*SrcAlign=*/element_alignment, element_count * primitive_type_size);
2502
2503 // The memcpy does the load and the store internally. The aliasing related
2504 // metadata has to reflect that.
2505 std::map<int, llvm::MDNode*> merged_metadata =
2506 llvm_ir::MergeMetadata(&module_->getContext(), source_array.metadata(),
2507 target_array.metadata());
2508 for (const auto& kind_md_pair : merged_metadata) {
2509 memcpy_instruction->setMetadata(kind_md_pair.first, kind_md_pair.second);
2510 }
2511 }
2512 }
2513
2514 Status IrEmitter::HandleConcatenate(HloInstruction* concatenate) {
2515 absl::Span<HloInstruction* const> operands(concatenate->operands());
2516 string failure_reason;
2517 TF_ASSIGN_OR_RETURN(
2518 bool successful,
2519 EmitFastConcatenate(concatenate, operands, &failure_reason));
2520 if (successful) {
2521 VLOG(1) << "Emitted fast concatenate for " << concatenate->ToString();
2522 return Status::OK();
2523 }
2524
2525 VLOG(1) << "Could not emit fast concatenate for " << concatenate->ToString()
2526 << ": " << failure_reason;
2527
2528 return DefaultAction(concatenate);
2529 }
2530
2531 Status IrEmitter::HandleConditional(HloInstruction* conditional) {
2532 auto branch_index = conditional->operand(0);
2533 int num_branches = conditional->branch_count();
2534 TF_RET_CHECK(ShapeUtil::IsScalar(branch_index->shape()) &&
2535 (branch_index->shape().element_type() == PRED ||
2536 branch_index->shape().element_type() == S32))
2537 << "Branch index on a conditional must be scalar bool or int32; got: "
2538 << ShapeUtil::HumanString(branch_index->shape());
2539
2540 for (int b = 0; b < num_branches; ++b) {
2541 HloComputation* br_computation = conditional->branch_computation(b);
2542 TF_RET_CHECK(ShapeUtil::Equal(conditional->shape(),
2543 br_computation->root_instruction()->shape()))
2544 << "Shape of conditional should be same as the shape of the " << b
2545 << "th branch computation; got: "
2546 << ShapeUtil::HumanString(conditional->shape()) << " and "
2547 << ShapeUtil::HumanString(br_computation->root_instruction()->shape());
2548 }
2549
2550 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(conditional));
2551
2552 if (branch_index->shape().element_type() == PRED) {
2553 // Emit an if-else to LLVM:
2554 // if (pred)
2555 // cond_result = true_computation(true_operand)
2556 // else
2557 // cond_result = false_computation(false_operand)
2558 llvm::LoadInst* pred_value = Load(
2559 GetIrArrayFor(branch_index).GetBasePointer(), "load_predicate_value");
2560 llvm::Value* pred_cond =
2561 ICmpNE(pred_value,
2562 llvm::ConstantInt::get(
2563 llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0),
2564 "boolean_predicate");
2565 llvm_ir::LlvmIfData if_data =
2566 llvm_ir::EmitIfThenElse(pred_cond, "conditional", &b_);
2567
2568 SetToFirstInsertPoint(if_data.true_block, &b_);
2569 EmitGlobalCall(*conditional->branch_computation(0),
2570 IrName(conditional, "_true"));
2571
2572 SetToFirstInsertPoint(if_data.false_block, &b_);
2573 EmitGlobalCall(*conditional->branch_computation(1),
2574 IrName(conditional, "_false"));
2575
2576 SetToFirstInsertPoint(if_data.after_block, &b_);
2577 return Status::OK();
2578 }
2579 // We emit a switch statement to LLVM:
2580 // switch (branch_index) {
2581 // default:
2582 // result = branch_computations[num_branches-1](operands[num_branches-1]);
2583 // break;
2584 // case 0:
2585 // result = branch_computations[0](operands[0]); break;
2586 // case 1:
2587 // result = branch_computations[1](operands[1]); break;
2588 // ...
2589 // case [[num_branches-2]]:
2590 // result = branch_computations[num_branches-2](operands[num_branches-2]);
2591 // break;
2592 // }
2593 llvm::LoadInst* branch_index_value = Load(
2594 GetIrArrayFor(branch_index).GetBasePointer(), "load_branch_index_value");
2595
2596 auto case_block = b_.GetInsertBlock();
2597 llvm::BasicBlock* after_block;
2598 // Add a terminator to the case block, if necessary.
2599 if (case_block->getTerminator() == nullptr) {
2600 after_block = llvm_ir::CreateBasicBlock(nullptr, "case-after", &b_);
2601 b_.SetInsertPoint(case_block);
2602 b_.CreateBr(after_block);
2603 } else {
2604 after_block =
2605 case_block->splitBasicBlock(b_.GetInsertPoint(), "case-after");
2606 }
2607 // Our basic block should now end with an unconditional branch. Remove it;
2608 // we're going to replace it with a switch based branch.
2609 case_block->getTerminator()->eraseFromParent();
2610
  // Lower the default branch computation.
  auto default_block = llvm_ir::CreateBasicBlock(nullptr, "case-default", &b_);
  b_.SetInsertPoint(default_block);
  EmitGlobalCall(*conditional->branch_computation(num_branches - 1),
                 IrName(conditional, "_default"));
  b_.CreateBr(after_block);

  // Prepare the switch (branch_index) { ... } instruction.
  b_.SetInsertPoint(case_block);
  llvm::SwitchInst* case_inst =
      b_.CreateSwitch(branch_index_value, default_block, num_branches - 1);
  // Lower each branch's computation.
  for (int b = 0; b < num_branches - 1; ++b) {  // last branch is default
    // Lower the case b: { ... ; break; } computation.
    auto branch_block =
        llvm_ir::CreateBasicBlock(nullptr, absl::StrCat("case-branch", b), &b_);
    b_.SetInsertPoint(branch_block);
    EmitGlobalCall(*conditional->branch_computation(b),
                   IrName(conditional, absl::StrCat("_branch", b)));
    b_.CreateBr(after_block);
    case_inst->addCase(b_.getInt32(b), branch_block);
  }

  SetToFirstInsertPoint(after_block, &b_);
  return Status::OK();
}

Status IrEmitter::HandleAfterAll(HloInstruction* after_all) {
  TF_RET_CHECK(ByteSizeOf(after_all->shape()) == 0);
  // No code to generate, but we need to emit an address for book-keeping.
  TF_RETURN_IF_ERROR(EmitTargetAddressForOp(after_all));
  return Status::OK();
}

Status IrEmitter::HandleAddDependency(HloInstruction* add_dependency) {
  // AddDependency just forwards its zero-th operand.
  emitted_value_[add_dependency] =
      GetEmittedValueFor(add_dependency->operand(0));
  return Status::OK();
}

Status IrEmitter::HandleRng(HloInstruction* rng) {
  ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator;
  for (const HloInstruction* operand : rng->operands()) {
    operand_to_generator[operand] = [=](const llvm_ir::IrArray::Index& index) {
      return GetIrArrayFor(operand).EmitReadArrayElement(index, &b_);
    };
  }

  CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_);
  TF_RETURN_IF_ERROR(EmitTargetElementLoop(
      rng, elemental_emitter.MakeElementGenerator(rng, operand_to_generator)));

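  // Hedged note: after the elemental loop, bump the module-level Philox RNG
  // state variable so that subsequent Rng instructions start from a different
  // offset rather than reusing this instruction's stream. This is a sketch of
  // intent only; the precise state layout lives in llvm_ir's RNG helpers.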
  llvm_ir::IncrementVariableForPhiloxRngState(1, module_, &b_);

  return Status::OK();
}

Status IrEmitter::FinishVisit(HloInstruction* root) {
  // When this method is called, we should have already emitted an IR value for
  // the root (return) op. The IR value holds the address of the buffer holding
  // the value. If the root is a constant or parameter, we perform a memcpy from
  // this buffer to the retval buffer of the computation. Otherwise, there's
  // nothing to do since the result was already written directly into the output
  // buffer.
  VLOG(2) << "FinishVisit root: " << root->ToString();
  if (root->opcode() == HloOpcode::kOutfeed) {
    VLOG(2) << " outfeed with value: "
            << llvm_ir::DumpToString(*GetEmittedValueFor(root->operand(0)));
  } else {
    VLOG(2) << " value: " << llvm_ir::DumpToString(*GetEmittedValueFor(root));
  }

  auto record_complete_computation = [&](llvm::Value* prof_counter) {
    if (prof_counter) {
      profiling_state_.RecordCompleteComputation(&b_, prof_counter);
    }
  };

  // For the entry computation this increment is cumulative over embedded
  // computations, since it includes cycles spent in computations invoked by
  // While, Call, etc.
  record_complete_computation(GetProfileCounterFor(*root->parent()));
  return Status::OK();
}

template <typename T>
llvm::Value* IrEmitter::GetProfileCounterCommon(
    const T& hlo,
    const std::unordered_map<const T*, int64>& profile_index_map) {
  auto it = profile_index_map.find(&hlo);
  if (it == profile_index_map.end()) {
    return nullptr;
  }

  int64 prof_counter_idx = it->second;
  string counter_name = IrName("prof_counter", hlo.name());
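  // The address computation below is roughly equivalent to this C (a sketch,
  // not emitted code):
  //   int64* counter = &profile_counters_arg[prof_counter_idx];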
  return GEP(GetProfileCountersArgument(), b_.getInt64(prof_counter_idx),
             counter_name);
}

void IrEmitter::ProfilingState::UpdateProfileCounter(llvm::IRBuilder<>* b,
                                                     llvm::Value* prof_counter,
                                                     llvm::Value* cycle_end,
                                                     llvm::Value* cycle_start) {
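  // The emitted IR amounts to a read-modify-write, roughly (sketch only):
  //   *prof_counter += cycle_end - cycle_start;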
  auto* cycle_diff = b->CreateSub(cycle_end, cycle_start);
  llvm::LoadInst* old_cycle_count =
      b->CreateLoad(prof_counter, "old_cycle_count");
  auto* new_cycle_count =
      b->CreateAdd(cycle_diff, old_cycle_count, "new_cycle_count");
  b->CreateStore(new_cycle_count, prof_counter);
}

llvm::Value* IrEmitter::ProfilingState::ReadCycleCounter(llvm::IRBuilder<>* b) {
  llvm::Module* module = b->GetInsertBlock()->getModule();
  if (!use_rdtscp_) {
    llvm::Function* func_llvm_readcyclecounter =
        llvm::Intrinsic::getDeclaration(module,
                                        llvm::Intrinsic::readcyclecounter);
    return b->CreateCall(func_llvm_readcyclecounter);
  }
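  // On the rdtscp path, the generated IR is roughly equivalent to this C
  // (a sketch; aux_i8ptr_ is a cached i8* view of a per-function alloca):
  //   unsigned int aux;
  //   unsigned long long cycles = __rdtscp(&aux);
  // The llvm.lifetime markers below scope the 4-byte aux slot around the call.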
  llvm::Function* func_llvm_x86_rdtscp =
      llvm::Intrinsic::getDeclaration(module, llvm::Intrinsic::x86_rdtscp);
  if (!aux_i8ptr_) {
    llvm::AllocaInst* rdtscp_aux =
        llvm_ir::EmitAllocaAtFunctionEntry(b->getInt32Ty(), "rdtscp_aux", b);
    aux_i8ptr_ = b->CreateBitCast(rdtscp_aux, b->getInt8PtrTy());
  }
  llvm::ConstantInt* alloca_size = b->getInt64(4);
  llvm::Function* func_llvm_lifetime_start =
      llvm::Intrinsic::getDeclaration(module, llvm::Intrinsic::lifetime_start);
  b->CreateCall(func_llvm_lifetime_start, {alloca_size, aux_i8ptr_});
  llvm::Value* rdtscp_call = b->CreateCall(func_llvm_x86_rdtscp, aux_i8ptr_);
  llvm::Function* func_llvm_lifetime_end =
      llvm::Intrinsic::getDeclaration(module, llvm::Intrinsic::lifetime_end);
  b->CreateCall(func_llvm_lifetime_end, {alloca_size, aux_i8ptr_});
  return rdtscp_call;
}

void IrEmitter::ProfilingState::RecordCycleStart(llvm::IRBuilder<>* b,
                                                 HloInstruction* hlo) {
  auto* cycle_start = ReadCycleCounter(b);
  cycle_start->setName(IrName(hlo, "cycle_start"));
  cycle_starts_[hlo] = cycle_start;
  if (first_read_cycle_start_ == nullptr) {
    first_read_cycle_start_ = cycle_start;
  }
}

void IrEmitter::ProfilingState::RecordCycleDelta(llvm::IRBuilder<>* b,
                                                 HloInstruction* hlo,
                                                 llvm::Value* prof_counter) {
  auto* cycle_end = ReadCycleCounter(b);
  cycle_end->setName(IrName(hlo, "cycle_end"));
  auto* cycle_start = cycle_starts_[hlo];
  UpdateProfileCounter(b, prof_counter, cycle_end, cycle_start);
  last_read_cycle_end_ = cycle_end;
}

void IrEmitter::ProfilingState::RecordCompleteComputation(
    llvm::IRBuilder<>* b, llvm::Value* prof_counter) {
  if (last_read_cycle_end_ && first_read_cycle_start_) {
    UpdateProfileCounter(b, prof_counter, last_read_cycle_end_,
                         first_read_cycle_start_);
  }
}

Status IrEmitter::Preprocess(HloInstruction* hlo) {
  VLOG(3) << "Visiting: " << hlo->ToString();
  if (instruction_to_profile_idx_.count(hlo)) {
    profiling_state_.RecordCycleStart(&b_, hlo);
  }
  return Status::OK();
}

Status IrEmitter::Postprocess(HloInstruction* hlo) {
  if (auto* prof_counter = GetProfileCounterFor(*hlo)) {
    profiling_state_.RecordCycleDelta(&b_, hlo, prof_counter);
  }
  return Status::OK();
}

llvm_ir::IrArray IrEmitter::GetIrArrayFor(const HloInstruction* hlo) {
  llvm::Value* value_for_op = GetEmittedValueFor(hlo);

  llvm_ir::IrArray array(value_for_op, hlo->shape());
  AddAliasingInformationToIrArray(*hlo, &array);
  return array;
}

std::vector<llvm_ir::IrArray> IrEmitter::GetIrArraysForOperandsOf(
    const HloInstruction* hlo) {
  std::vector<llvm_ir::IrArray> arrays;
  std::transform(
      hlo->operands().begin(), hlo->operands().end(),
      std::back_inserter(arrays),
      [&](const HloInstruction* operand) { return GetIrArrayFor(operand); });
  return arrays;
}

llvm::Value* IrEmitter::GetEmittedValueFor(const HloInstruction* hlo) {
  auto it = emitted_value_.find(hlo);
  if (it == emitted_value_.end()) {
    LOG(FATAL) << "could not find emitted value for: " << hlo->ToString();
  }
  return it->second;
}

llvm::Type* IrEmitter::IrShapeType(const Shape& shape) {
  return llvm_ir::ShapeToIrType(shape, module_);
}

llvm::Value* IrEmitter::GetProfileCountersArgument() {
  return compute_function_->profile_counters_arg();
}

llvm::Value* IrEmitter::GetBufferTableArgument() {
  return compute_function_->buffer_table_arg();
}

llvm::Value* IrEmitter::GetExecutableRunOptionsArgument() {
  return compute_function_->exec_run_options_arg();
}

llvm::Value* IrEmitter::EmitThreadLocalBufferPointer(
    const BufferAllocation::Slice& slice, const Shape& target_shape) {
  const BufferAllocation& allocation = *slice.allocation();
  llvm::Value* tempbuf_address = [&]() -> llvm::Value* {
    if (slice == computation_root_allocation_) {
      llvm::Argument* retval = compute_function_->result_arg();
      llvm::AttrBuilder attr_builder;
      attr_builder.addAlignmentAttr(MinimumAlignmentForShape(target_shape));
      attr_builder.addDereferenceableAttr(ByteSizeOf(target_shape));
      retval->addAttrs(attr_builder);
      return retval;
    }

    auto param_it =
        computation_parameter_allocations_.find(slice.allocation()->index());
    if (param_it != computation_parameter_allocations_.end()) {
      int64 param_number = param_it->second;
      // We have to access the parameter at offset param_number in the params
      // array. The code generated here is equivalent to this C code:
      //
      //   i8* param_address_untyped = params[param_number];
      //   Param* param_address_typed = (Param*)param_address_untyped;
      //
      // Where Param is the actual element type of the underlying buffer (for
      // example, float for an XLA F32 element type).
      llvm::Value* params = compute_function_->parameters_arg();
      llvm::Value* param_address_offset =
          llvm_ir::EmitBufferIndexingGEP(params, param_number, &b_);
      llvm::LoadInst* param_address_untyped = Load(param_address_offset);

      if (!target_shape.IsOpaque()) {
        AttachAlignmentMetadataForLoad(param_address_untyped, target_shape);
        AttachDereferenceableMetadataForLoad(param_address_untyped,
                                             target_shape);
      }
      return param_address_untyped;
    }

    // Thread-local allocations should only be assigned a single buffer.
    const auto& assigned_buffers = allocation.assigned_buffers();
    CHECK_EQ(1, assigned_buffers.size());
    const Shape& shape = assigned_buffers.begin()->first->shape();

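    // For any other thread-local slice, fall back to a function-entry stack
    // slot cached per {current function, slice}. Roughly (a sketch of the
    // lookup below):
    //   if (!cache.contains(key)) cache[key] = alloca(shape);
    //   return cache[key];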
    std::pair<llvm::Function*, BufferAllocation::Slice> key = {
        compute_function_->function(), slice};
    auto buf_it = thread_local_buffers_.find(key);
    if (buf_it == thread_local_buffers_.end()) {
      llvm::Value* buffer = llvm_ir::EmitAllocaAtFunctionEntry(
          IrShapeType(shape), absl::StrCat("thread_local", slice.ToString()),
          &b_, MinimumAlignmentForShape(target_shape));
      auto it_inserted_pair = thread_local_buffers_.insert({key, buffer});
      CHECK(it_inserted_pair.second);
      buf_it = it_inserted_pair.first;
    }
    return buf_it->second;
  }();
  return BitCast(tempbuf_address, IrShapeType(target_shape)->getPointerTo());
}

llvm::Value* IrEmitter::EmitGlobalBufferPointer(
    const BufferAllocation::Slice& slice, const Shape& target_shape) {
  const BufferAllocation& allocation = *slice.allocation();
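  // The address computation below is roughly equivalent to this C
  // (a sketch; buffer_table is the i8** runtime argument):
  //   i8* base = buffer_table[slice.index()];
  //   return (T*)(base + slice.offset());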
  llvm::Value* tempbuf_address_ptr = llvm_ir::EmitBufferIndexingGEP(
      GetBufferTableArgument(), slice.index(), &b_);
  llvm::LoadInst* tempbuf_address_base = Load(tempbuf_address_ptr);
  if (hlo_module_config_.debug_options()
          .xla_llvm_enable_invariant_load_metadata()) {
    tempbuf_address_base->setMetadata(
        llvm::LLVMContext::MD_invariant_load,
        llvm::MDNode::get(tempbuf_address_base->getContext(), /*MDs=*/{}));
  }
  AttachAlignmentMetadataForLoad(tempbuf_address_base, allocation.size());
  AttachDereferenceableMetadataForLoad(tempbuf_address_base, allocation.size());

  llvm::Value* tempbuf_address_untyped = tempbuf_address_base;
  if (slice.offset() > 0) {
    // Adjust the address to account for the slice offset.
    tempbuf_address_untyped =
        InBoundsGEP(tempbuf_address_base, b_.getInt64(slice.offset()));
  }
  return BitCast(tempbuf_address_untyped,
                 IrShapeType(target_shape)->getPointerTo());
}

llvm::Value* IrEmitter::EmitBufferPointer(const BufferAllocation::Slice& slice,
                                          const Shape& target_shape) {
  if (slice.allocation()->is_thread_local()) {
    return EmitThreadLocalBufferPointer(slice, target_shape);
  } else if (slice.allocation()->is_constant()) {
    return BitCast(
        FindOrDie(constant_buffer_to_global_, slice.allocation()->index()),
        IrShapeType(target_shape)->getPointerTo());
  } else {
    return EmitGlobalBufferPointer(slice, target_shape);
  }
}

Status IrEmitter::EmitTargetAddressForOp(const HloInstruction* op) {
  const Shape& target_shape = op->shape();
  TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice,
                      assignment_.GetUniqueTopLevelSlice(op));
  llvm::Value* addr = EmitBufferPointer(slice, target_shape);
  addr->setName(IrName(op));
  emitted_value_[op] = addr;
  return Status::OK();
}

Status IrEmitter::EmitTargetElementLoop(
    HloInstruction* target_op,
    const llvm_ir::ElementGenerator& element_generator) {
  return EmitTargetElementLoop(target_op, /*desc=*/"", element_generator);
}

Status IrEmitter::EmitTargetElementLoop(
    HloInstruction* target_op, absl::string_view desc,
    const llvm_ir::ElementGenerator& element_generator) {
  VLOG(2) << "EmitTargetElementLoop: " << target_op->ToString();

  const Shape& target_shape = target_op->shape();
  TF_RETURN_IF_ERROR(EmitTargetAddressForOp(target_op));
  llvm_ir::IrArray target_array = GetIrArrayFor(target_op);

  if (target_op->IsMultiOutputFusion()) {
    // For multi-output fusion, emit an array for each tuple element and then
    // write the tuple of their base pointers into the root buffer.
    TF_RET_CHECK(num_dynamic_loop_bounds_ == 0);
    std::vector<llvm_ir::IrArray> output_arrays;
    for (int64 i = 0; i < ShapeUtil::TupleElementCount(target_shape); ++i) {
      TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice,
                          assignment_.GetUniqueSlice(target_op, {i}));
      const Shape& element_shape = ShapeUtil::GetSubshape(target_shape, {i});
      llvm::Value* op_target_address = EmitBufferPointer(slice, element_shape);
      output_arrays.push_back(
          llvm_ir::IrArray(op_target_address, element_shape));
    }
    TF_RETURN_IF_ERROR(
        llvm_ir::LoopEmitter(element_generator, output_arrays, &b_)
            .EmitLoop(IrName(target_op)));

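    // Conceptually the root buffer ends up holding (a sketch of what EmitTuple
    // stores below):
    //   void* tuple[] = {output_arrays[0].base, ..., output_arrays[N-1].base};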
    std::vector<llvm::Value*> tuple_operand_ptrs;
    for (int64 i = 0; i < output_arrays.size(); ++i) {
      tuple_operand_ptrs.push_back(output_arrays[i].GetBasePointer());
    }
    llvm_ir::EmitTuple(target_array, tuple_operand_ptrs, &b_);

  } else {
    if (ShouldEmitParallelLoopFor(*target_op)) {
      // Emit code to read dynamic loop bounds from compute function argument.
      std::vector<std::pair<llvm::Value*, llvm::Value*>> dynamic_loop_bounds =
          compute_function_->GetDynamicLoopBounds();
      // Emit parallel loop with dynamic loop bounds for most-major dimensions.
      TF_RETURN_IF_ERROR(ParallelLoopEmitter(element_generator, target_array,
                                             &dynamic_loop_bounds, &b_)
                             .EmitLoop(IrName(target_op)));
    } else {
      TF_RETURN_IF_ERROR(
          llvm_ir::LoopEmitter(element_generator, target_array, &b_)
              .EmitLoop(IrName(target_op)));
    }
  }
  return Status::OK();
}

Status IrEmitter::EmitMemcpy(const HloInstruction& source,
                             const HloInstruction& destination) {
  llvm::Value* source_value = GetEmittedValueFor(&source);
  llvm::Value* destination_value = GetEmittedValueFor(&destination);
  int64 source_size = ByteSizeOf(source.shape());
  // TODO(b/63762267): Be more aggressive about specifying alignment.
  MemCpy(destination_value, /*DstAlign=*/1, source_value,
         /*SrcAlign=*/1, source_size);
  return Status::OK();
}

Status IrEmitter::ElementTypesSameAndSupported(
    const HloInstruction& instruction,
    absl::Span<const HloInstruction* const> operands,
    absl::Span<const PrimitiveType> supported_types) {
  for (auto operand : operands) {
    TF_RET_CHECK(
        ShapeUtil::SameElementType(operands[0]->shape(), operand->shape()));
  }

  TF_RET_CHECK(!operands.empty());
  PrimitiveType primitive_type = operands[0]->shape().element_type();
  if (!absl::c_linear_search(supported_types, primitive_type)) {
    return Unimplemented("unsupported operand type %s in op %s",
                         PrimitiveType_Name(primitive_type),
                         HloOpcodeString(instruction.opcode()));
  }
  return Status::OK();
}

Status IrEmitter::DefaultAction(HloInstruction* hlo) {
  ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator;
  for (const HloInstruction* operand : hlo->operands()) {
    operand_to_generator[operand] = [=](const llvm_ir::IrArray::Index& index) {
      return GetIrArrayFor(operand).EmitReadArrayElement(index, &b_);
    };
  }
  CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_);
  return EmitTargetElementLoop(
      hlo, elemental_emitter.MakeElementGenerator(hlo, operand_to_generator));
}

llvm::Value* IrEmitter::EmitThreadLocalCall(
    const HloComputation& callee, absl::Span<llvm::Value* const> parameters,
    absl::string_view name) {
  CHECK(absl::c_binary_search(thread_local_computations_, &callee));

  const Shape& return_shape = callee.root_instruction()->shape();

  // Lifting this restriction to allow "small" arrays should be easy. Allowing
  // larger arrays is difficult because we allocate the buffer for this return
  // value on the stack.
  CHECK(ShapeUtil::IsScalar(return_shape));

  PrimitiveType return_type = return_shape.element_type();

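  // Calling-convention sketch for thread-local calls (hedged, not the exact
  // emitted signature): each scalar argument is spilled to a function-entry
  // alloca, a scalar return slot is allocated the same way, and the callee is
  // handed a null buffer table since thread-local computations only touch
  // these stack slots. Conceptually:
  //   store arg_i -> arg_i_slot;                  // one alloca per parameter
  //   callee(retval_slot, run_options, arg_slots, /*buffer_table=*/null, ...);
  //   return load retval_slot;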
  std::vector<llvm::Value*> parameter_addrs;
  for (llvm::Value* parameter : parameters) {
    CHECK(!parameter->getType()->isPointerTy());
    llvm::Value* parameter_addr = llvm_ir::EmitAllocaAtFunctionEntry(
        parameter->getType(), "arg_addr", &b_);
    Store(parameter, parameter_addr);
    parameter_addrs.push_back(parameter_addr);
  }

  llvm::Value* return_value_buffer = llvm_ir::EmitAllocaAtFunctionEntry(
      llvm_ir::PrimitiveTypeToIrType(return_type, module_),
      absl::StrCat(name, "_retval_addr"), &b_,
      MinimumAlignmentForPrimitiveType(return_type));

  Call(FindOrDie(emitted_functions_, &callee),
       GetArrayFunctionCallArguments(
           parameter_addrs, &b_, name,
           /*return_value_buffer=*/return_value_buffer,
           /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(),
           /*buffer_table_arg=*/
           llvm::Constant::getNullValue(b_.getInt8PtrTy()->getPointerTo()),
           /*profile_counters_arg=*/GetProfileCountersArgument()));

  return Load(return_value_buffer);
}

void IrEmitter::EmitGlobalCall(const HloComputation& callee,
                               absl::string_view name) {
  CHECK(absl::c_binary_search(global_computations_, &callee));

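  // Hedged summary of the call below: global computations read and write
  // buffers assigned by buffer assignment, so the call forwards the shared
  // buffer table and passes a null return buffer; callers that need the
  // result fetch it afterwards via GetBufferForGlobalCallReturnValue.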
  Call(FindOrDie(emitted_functions_, &callee),
       GetArrayFunctionCallArguments(
           /*parameter_addresses=*/{}, &b_, name,
           /*return_value_buffer=*/
           llvm::Constant::getNullValue(b_.getInt8PtrTy()),
           /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(),
           /*buffer_table_arg=*/GetBufferTableArgument(),
           /*profile_counters_arg=*/GetProfileCountersArgument()));
}

llvm::Value* IrEmitter::GetBufferForGlobalCallReturnValue(
    const HloComputation& callee) {
  const HloInstruction* root_inst = callee.root_instruction();
  if (root_inst->opcode() == HloOpcode::kOutfeed) {
    return llvm::Constant::getNullValue(b_.getInt8PtrTy());
  }

  const BufferAllocation::Slice root_buffer =
      assignment_.GetUniqueTopLevelSlice(root_inst).ValueOrDie();
  return EmitBufferPointer(root_buffer, root_inst->shape());
}

}  // namespace cpu
}  // namespace xla