1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/cpu/ir_emitter.h"
17
18 #include <stddef.h>
19 #include <stdint.h>
20
21 #include <algorithm>
22 #include <iterator>
23 #include <limits>
24 #include <memory>
25 #include <utility>
26 #include <vector>
27
28 // IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
29 #include "absl/container/flat_hash_map.h"
30 #include "absl/container/flat_hash_set.h"
31 #include "absl/strings/str_cat.h"
32 #include "absl/strings/str_format.h"
33 #include "absl/strings/str_join.h"
34 #include "absl/types/span.h"
35 #include "llvm/CodeGen/TargetRegisterInfo.h"
36 #include "llvm/CodeGen/TargetSubtargetInfo.h"
37 #include "llvm/IR/BasicBlock.h"
38 #include "llvm/IR/Constants.h"
39 #include "llvm/IR/GlobalVariable.h"
40 #include "llvm/IR/Instructions.h"
41 #include "llvm/IR/Intrinsics.h"
42 #include "llvm/IR/IntrinsicsX86.h"
43 #include "llvm/IR/LLVMContext.h"
44 #include "llvm/IR/Value.h"
45 #include "tensorflow/compiler/xla/layout_util.h"
46 #include "tensorflow/compiler/xla/map_util.h"
47 #include "tensorflow/compiler/xla/primitive_util.h"
48 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
49 #include "tensorflow/compiler/xla/service/collective_ops_utils.h"
50 #include "tensorflow/compiler/xla/service/cpu/cpu_options.h"
51 #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
52 #include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h"
53 #include "tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h"
54 #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
55 #include "tensorflow/compiler/xla/service/cpu/ir_function.h"
56 #include "tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h"
57 #include "tensorflow/compiler/xla/service/cpu/shape_partition.h"
58 #include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h"
59 #include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"
60 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
61 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
62 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
63 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
64 #include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h"
65 #include "tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h"
66 #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
67 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
68 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
69 #include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
70 #include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h"
71 #include "tensorflow/compiler/xla/shape_util.h"
72 #include "tensorflow/compiler/xla/status_macros.h"
73 #include "tensorflow/compiler/xla/types.h"
74 #include "tensorflow/compiler/xla/util.h"
75 #include "tensorflow/compiler/xla/window_util.h"
76 #include "tensorflow/compiler/xla/xla_data.pb.h"
77 #include "tensorflow/core/lib/core/bits.h"
78 #include "tensorflow/core/lib/core/errors.h"
79 #include "tensorflow/core/lib/math/math_util.h"
80 #include "tensorflow/core/platform/logging.h"
81
82 namespace xla {
83
84 namespace {
85 using llvm_ir::IrName;
86 using llvm_ir::SetToFirstInsertPoint;
87 } // namespace
88
89 namespace cpu {
90
IrEmitter(mlir::MLIRContext * mlir_context,const HloModule & hlo_module,const BufferAssignment & assignment,llvm::Module * llvm_module,std::unordered_map<const HloInstruction *,int64> instruction_to_profile_idx,std::unordered_map<const HloComputation *,int64> computation_to_profile_idx,const TargetMachineFeatures * target_machine_features,bool emit_code_for_msan)91 IrEmitter::IrEmitter(
92 mlir::MLIRContext* mlir_context, const HloModule& hlo_module,
93 const BufferAssignment& assignment, llvm::Module* llvm_module,
94 std::unordered_map<const HloInstruction*, int64> instruction_to_profile_idx,
95 std::unordered_map<const HloComputation*, int64> computation_to_profile_idx,
96 const TargetMachineFeatures* target_machine_features,
97 bool emit_code_for_msan)
98 : assignment_(assignment),
99 module_(llvm_module),
100 arch_type_(llvm::Triple(llvm_module->getTargetTriple()).getArch()),
101 b_(llvm_module->getContext()),
102 mlir_context_(mlir_context),
103 instruction_to_profile_idx_(std::move(instruction_to_profile_idx)),
104 computation_to_profile_idx_(std::move(computation_to_profile_idx)),
105 alias_analysis_(hlo_module, assignment, &llvm_module->getContext()),
106 hlo_module_config_(hlo_module.config()),
107 is_top_level_computation_(false),
108 target_machine_features_(*target_machine_features),
109 emit_code_for_msan_(emit_code_for_msan) {
110 b_.setFastMathFlags(llvm_ir::GetCpuFastMathFlags(hlo_module_config_));
111 Status s = GatherComputationsByAllocationType(
112 &hlo_module, &thread_local_computations_, &global_computations_);
113 absl::c_sort(thread_local_computations_);
114 absl::c_sort(global_computations_);
115 TF_CHECK_OK(s) << "Should have failed buffer assignment.";
116 }
117
EmitThreadLocalFunctionEpilogue(HloComputation * computation)118 void IrEmitter::EmitThreadLocalFunctionEpilogue(HloComputation* computation) {
119 llvm::Argument* out_parameter = compute_function_->result_arg();
120 llvm_ir::IrArray root_value = GetIrArrayFor(computation->root_instruction());
121 const Shape& return_shape = computation->root_instruction()->shape();
122
123 if (ShapeUtil::IsScalar(return_shape)) {
124 llvm::Value* ret_value =
125 Load(root_value.GetBasePointer(), "load_ret_value");
126 Store(ret_value,
127 BitCast(out_parameter, root_value.GetBasePointer()->getType()));
128 } else {
129 CHECK(return_shape.IsTuple());
130
131 llvm::Type* tuple_type = llvm_ir::ShapeToIrType(return_shape, module_);
132 llvm::Type* tuple_type_lvalue = tuple_type->getPointerTo();
133 llvm::Value* tuple_lvalue = BitCast(out_parameter, tuple_type_lvalue);
134
135 for (int i = 0; i < return_shape.tuple_shapes_size(); i++) {
136 const Shape& element_shape = return_shape.tuple_shapes(i);
137 llvm::Value* destination = llvm_ir::EmitGetTupleElement(
138 element_shape,
139 /*index=*/i,
140 /*alignment=*/MinimumAlignmentForShape(element_shape), tuple_lvalue,
141 &b_);
142
143 llvm::Value* source = llvm_ir::EmitGetTupleElement(
144 element_shape,
145 /*index=*/i,
146 /*alignment=*/MinimumAlignmentForShape(element_shape),
147 root_value.GetBasePointer(), &b_);
148
149 Store(Load(source), destination);
150 }
151 }
152 }
153
EmitComputation(HloComputation * computation,const string & function_name_prefix,bool is_top_level_computation,absl::Span<HloInstruction * const> instruction_order)154 StatusOr<llvm::Function*> IrEmitter::EmitComputation(
155 HloComputation* computation, const string& function_name_prefix,
156 bool is_top_level_computation,
157 absl::Span<HloInstruction* const> instruction_order) {
158 string function_name = name_uniquer_.GetUniqueName(function_name_prefix);
159 VLOG(2) << "Emitting IR for CPU function [" << function_name_prefix << "]";
160 is_top_level_computation_ = is_top_level_computation;
161 num_dynamic_loop_bounds_ = 0;
162 if (!computation->root_instruction()->outer_dimension_partitions().empty()) {
163 num_dynamic_loop_bounds_ =
164 computation->root_instruction()->outer_dimension_partitions().size();
165 }
166
167 if (computation->root_instruction()->opcode() != HloOpcode::kOutfeed) {
168 TF_ASSIGN_OR_RETURN(
169 computation_root_allocation_,
170 assignment_.GetUniqueTopLevelSlice(computation->root_instruction()));
171 }
172
173 for (const HloInstruction* param : computation->parameter_instructions()) {
174 TF_ASSIGN_OR_RETURN(BufferAllocation::Slice param_slice,
175 assignment_.GetUniqueTopLevelSlice(param));
176 computation_parameter_allocations_[param_slice.allocation()->index()] =
177 param->parameter_number();
178 }
179
180 InitializeIrFunction(function_name);
181 // The rdtscp instruction is x86 specific. We will fallback to LLVM's generic
182 // readcyclecounter if it is unavailable.
183 bool use_rdtscp = arch_type_ == llvm::Triple::ArchType::x86 ||
184 arch_type_ == llvm::Triple::ArchType::x86_64;
185 profiling_state_ = ProfilingState(use_rdtscp);
186
187 tracing_state_.set_enabled(
188 computation->parent()->config().cpu_traceme_enabled());
189
190 TF_RETURN_IF_ERROR(computation->AcceptOrdered(this, instruction_order));
191 llvm::Function* ir_function = compute_function_->function();
192 InsertOrDie(&emitted_functions_, computation, ir_function);
193 // Delete 'compute_function', finalizing 'ir_function' and restoring caller
194 // IR insert point.
195
196 // Function epilogue: copying the value over to either the return register,
197 // or values pointing from the return register.
198 const BufferAllocation* root_allocation =
199 computation_root_allocation_.allocation();
200 if (root_allocation && root_allocation->is_thread_local()) {
201 EmitThreadLocalFunctionEpilogue(computation);
202 }
203
204 // Destructor for compute_function_ emits the "ret void" instruction.
205 compute_function_.reset();
206 computation_root_allocation_ = BufferAllocation::Slice();
207 computation_parameter_allocations_.clear();
208 return ir_function;
209 }
210
InitializeIrFunction(const string & function_name)211 void IrEmitter::InitializeIrFunction(const string& function_name) {
212 // Functions with local linkage get an inlining bonus. Because we know
213 // a-priori that embedded functions (non-entry functions) will not have its
214 // name resolved, give it local linkage.
215 llvm::Function::LinkageTypes linkage =
216 is_top_level_computation_ ? llvm::GlobalValue::ExternalLinkage
217 : llvm::GlobalValue::InternalLinkage;
218 // Create and initialize new IrFunction.
219 compute_function_.reset(new IrFunction(function_name, linkage,
220 hlo_module_config_, module_, &b_,
221 num_dynamic_loop_bounds_));
222 }
223
~IrEmitter()224 IrEmitter::~IrEmitter() {}
225
HandleBitcast(HloInstruction * bitcast)226 Status IrEmitter::HandleBitcast(HloInstruction* bitcast) {
227 VLOG(2) << "HandleBitcast: " << bitcast->ToString();
228 emitted_value_[bitcast] =
229 BitCast(GetEmittedValueFor(bitcast->operand(0)),
230 IrShapeType(bitcast->shape())->getPointerTo(), IrName(bitcast));
231 return Status::OK();
232 }
233
EmitGlobalForLiteral(const Literal & literal)234 llvm::Constant* IrEmitter::EmitGlobalForLiteral(const Literal& literal) {
235 llvm::Constant* initializer =
236 llvm_ir::ConvertLiteralToIrConstant(literal, module_);
237 llvm::GlobalVariable* result_global = new llvm::GlobalVariable(
238 /*Module=*/*module_,
239 /*Type=*/initializer->getType(),
240 /*isConstant=*/true,
241 /*Linkage=*/llvm::GlobalValue::PrivateLinkage,
242 /*Initializer=*/initializer,
243 /*Name=*/"");
244 result_global->setAlignment(
245 llvm::Align(MinimumAlignmentForShape(literal.shape())));
246 result_global->setUnnamedAddr(llvm::GlobalVariable::UnnamedAddr::Global);
247 return llvm::ConstantExpr::getBitCast(
248 result_global, IrShapeType(literal.shape())->getPointerTo());
249 }
250
EmitConstantGlobals()251 Status IrEmitter::EmitConstantGlobals() {
252 for (const BufferAllocation& allocation : assignment_.Allocations()) {
253 if (!allocation.is_constant()) {
254 continue;
255 }
256
257 const Literal& literal = llvm_ir::LiteralForConstantAllocation(allocation);
258 llvm::Constant* global_for_const;
259 auto it = emitted_literals_.find(&literal);
260 if (it != emitted_literals_.end()) {
261 global_for_const = it->second;
262 } else {
263 global_for_const = EmitGlobalForLiteral(literal);
264 InsertOrDie(&emitted_literals_, &literal, global_for_const);
265 }
266
267 InsertOrDie(&constant_buffer_to_global_, allocation.index(),
268 global_for_const);
269 }
270
271 return Status::OK();
272 }
273
HandleConstant(HloInstruction * constant)274 Status IrEmitter::HandleConstant(HloInstruction* constant) {
275 VLOG(2) << "HandleConstant: " << constant->ToString();
276 // IrEmitter::EmitConstantGlobals has already taken care of emitting the body
277 // of the constant.
278 return EmitTargetAddressForOp(constant);
279 }
280
HandleCopy(HloInstruction * copy)281 Status IrEmitter::HandleCopy(HloInstruction* copy) {
282 if (copy->shape().IsTuple() ||
283 (copy->shape().IsArray() &&
284 LayoutUtil::Equal(copy->operand(0)->shape().layout(),
285 copy->shape().layout()))) {
286 // If the layouts are equal this is just a memcpy. kCopy shallow copies a
287 // tuple so just memcpy the top-level buffer for tuples.
288 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(copy));
289 return EmitMemcpy(*(copy->operand(0)), *copy);
290 } else if (copy->shape().IsArray()) {
291 // Use the elemental emitter for array shapes.
292 return DefaultAction(copy);
293 }
294 return Unimplemented("unsupported operand type %s for copy instruction",
295 PrimitiveType_Name(copy->shape().element_type()));
296 }
297
298 // Calculate the alignment of a buffer allocated for a given primitive type.
MinimumAlignmentForPrimitiveType(PrimitiveType primitive_type)299 int IrEmitter::MinimumAlignmentForPrimitiveType(PrimitiveType primitive_type) {
300 int64_t byte_size = ShapeUtil::ByteSizeOfPrimitiveType(primitive_type);
301 DCHECK_GE(byte_size, 0);
302 // Largest scalar is a complex128 so we don't need to worry about the
303 // int64->int truncation here.
304 DCHECK_LE(byte_size, 16);
305
306 // Allocations may be 8-byte aligned if part of a small block.
307 return std::min(int64{8}, byte_size);
308 }
309
ByteSizeOf(const Shape & shape) const310 int64 IrEmitter::ByteSizeOf(const Shape& shape) const {
311 return llvm_ir::ByteSizeOf(shape, module_->getDataLayout());
312 }
313
314 // Calculate the alignment of a buffer allocated for a given shape.
MinimumAlignmentForShape(const Shape & shape)315 int IrEmitter::MinimumAlignmentForShape(const Shape& shape) {
316 if (ShapeUtil::IsScalar(shape)) {
317 return MinimumAlignmentForPrimitiveType(shape.element_type());
318 }
319
320 int64_t buffer_size = ByteSizeOf(shape);
321 DCHECK_GE(buffer_size, 0);
322 DCHECK_LE(buffer_size, SIZE_MAX);
323
324 return target_machine_features_.minimum_alignment_for_allocation(buffer_size);
325 }
326
AttachAlignmentMetadataForLoad(llvm::LoadInst * load,const Shape & shape)327 void IrEmitter::AttachAlignmentMetadataForLoad(llvm::LoadInst* load,
328 const Shape& shape) {
329 int alignment = MinimumAlignmentForShape(shape);
330 if (alignment > 1) {
331 llvm_ir::SetAlignmentMetadataForLoad(load, alignment);
332 }
333 }
334
AttachAlignmentMetadataForLoad(llvm::LoadInst * load,int64_t buffer_size)335 void IrEmitter::AttachAlignmentMetadataForLoad(llvm::LoadInst* load,
336 int64_t buffer_size) {
337 int alignment =
338 target_machine_features_.minimum_alignment_for_allocation(buffer_size);
339 if (alignment > 1) {
340 llvm_ir::SetAlignmentMetadataForLoad(load, alignment);
341 }
342 }
343
AttachDereferenceableMetadataForLoad(llvm::LoadInst * load,const Shape & shape)344 void IrEmitter::AttachDereferenceableMetadataForLoad(llvm::LoadInst* load,
345 const Shape& shape) {
346 AttachDereferenceableMetadataForLoad(load, ByteSizeOf(shape));
347 }
348
AttachDereferenceableMetadataForLoad(llvm::LoadInst * load,int64_t buffer_size)349 void IrEmitter::AttachDereferenceableMetadataForLoad(llvm::LoadInst* load,
350 int64_t buffer_size) {
351 if (buffer_size > 0) {
352 llvm_ir::SetDereferenceableMetadataForLoad(load, buffer_size);
353 }
354 }
355
HandleGetTupleElement(HloInstruction * get_tuple_element)356 Status IrEmitter::HandleGetTupleElement(HloInstruction* get_tuple_element) {
357 // A tuple is an array of pointers, one for each operand. Each pointer points
358 // to the output buffer of its corresponding operand. A GetTupleElement
359 // instruction forwards a pointer to the tuple element buffer at the given
360 // index.
361 const HloInstruction* operand = get_tuple_element->operand(0);
362 const Shape& shape = get_tuple_element->shape();
363 emitted_value_[get_tuple_element] = llvm_ir::EmitGetTupleElement(
364 shape, get_tuple_element->tuple_index(), MinimumAlignmentForShape(shape),
365 GetEmittedValueFor(operand), &b_);
366 return Status::OK();
367 }
368
HandleSelect(HloInstruction * select)369 Status IrEmitter::HandleSelect(HloInstruction* select) {
370 auto pred = select->operand(0);
371 TF_RET_CHECK(pred->shape().element_type() == PRED);
372 return DefaultAction(select);
373 }
374
HandleTupleSelect(HloInstruction * tuple_select)375 Status IrEmitter::HandleTupleSelect(HloInstruction* tuple_select) {
376 auto pred = tuple_select->operand(0);
377 auto on_true = tuple_select->operand(1);
378 auto on_false = tuple_select->operand(2);
379 TF_RET_CHECK(pred->shape().element_type() == PRED);
380 TF_RET_CHECK(ShapeUtil::IsScalar(pred->shape()));
381 TF_RET_CHECK(tuple_select->shape().IsTuple());
382 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(tuple_select));
383 llvm_ir::EmitTupleSelect(GetIrArrayFor(tuple_select), GetIrArrayFor(pred),
384 GetEmittedValueFor(on_true),
385 GetEmittedValueFor(on_false), &b_);
386 return Status::OK();
387 }
388
HandleInfeed(HloInstruction * instruction)389 Status IrEmitter::HandleInfeed(HloInstruction* instruction) {
390 HloInfeedInstruction* infeed = Cast<HloInfeedInstruction>(instruction);
391 VLOG(2) << "HandleInfeed: " << infeed->ToString();
392
393 // The infeed operation produces a two-element tuple containing data and a
394 // token value. HloInfeedInstruction::infeed_shape gives us the data shape.
395 const Shape& data_shape = infeed->infeed_shape();
396 DCHECK(ShapeUtil::Equal(data_shape,
397 ShapeUtil::GetTupleElementShape(infeed->shape(), 0)));
398 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(infeed));
399
400 // Write the tuple index table.
401 TF_ASSIGN_OR_RETURN(BufferAllocation::Slice data_slice,
402 assignment_.GetUniqueSlice(infeed, {0}));
403 llvm::Value* data_address = EmitBufferPointer(data_slice, data_shape);
404 TF_ASSIGN_OR_RETURN(BufferAllocation::Slice token_slice,
405 assignment_.GetUniqueSlice(infeed, {1}));
406 llvm::Value* token_address = EmitBufferPointer(
407 token_slice, ShapeUtil::GetTupleElementShape(infeed->shape(), 1));
408 llvm_ir::EmitTuple(GetIrArrayFor(infeed), {data_address, token_address}, &b_);
409
410 if (data_shape.IsTuple()) {
411 TF_RET_CHECK(!ShapeUtil::IsNestedTuple(data_shape));
412
413 // For a tuple, we first copy each of the internal elements to
414 // their corresponding target locations. We then construct the
415 // tuple outer buffer containing pointers to the internal
416 // elements.
417 std::vector<llvm::Value*> tuple_element_addresses;
418 for (int64_t i = 0; i < data_shape.tuple_shapes_size(); ++i) {
419 TF_ASSIGN_OR_RETURN(BufferAllocation::Slice buffer,
420 assignment_.GetUniqueSlice(infeed, {0, i}));
421
422 const Shape& tuple_element_shape =
423 ShapeUtil::GetTupleElementShape(data_shape, i);
424
425 // Only the outer tuple buffer's target address is obtained from
426 // GetEmittedValueFor, to handle the case when Infeed is the root
427 // instruction. Target addresses for internal elements can be obtained
428 // from EmitBufferPointer.
429 llvm::Value* tuple_element_address =
430 EmitBufferPointer(buffer, tuple_element_shape);
431
432 TF_RETURN_IF_ERROR(EmitXfeedTransfer(
433 XfeedKind::kInfeed, tuple_element_shape, tuple_element_address));
434
435 tuple_element_addresses.push_back(tuple_element_address);
436 }
437
438 llvm_ir::EmitTuple(llvm_ir::IrArray(data_address, data_shape),
439 tuple_element_addresses, &b_);
440 } else {
441 TF_RETURN_IF_ERROR(
442 EmitXfeedTransfer(XfeedKind::kInfeed, data_shape, data_address));
443 }
444
445 return Status::OK();
446 }
447
EmitXfeedTransfer(XfeedKind kind,const Shape & shape,llvm::Value * program_buffer_address)448 Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape,
449 llvm::Value* program_buffer_address) {
450 int64_t length = ByteSizeOf(shape);
451 if (length < 0 || length > std::numeric_limits<int32>::max()) {
452 return InvalidArgument(
453 "xfeed (infeed or outfeed) buffer length %d is outside the valid "
454 "size range",
455 length);
456 }
457 int32_t length_32 = static_cast<int32>(length);
458
459 int32_t shape_length;
460 TF_ASSIGN_OR_RETURN(
461 llvm::Value * shape_ptr,
462 llvm_ir::EncodeSelfDescribingShapeConstant(shape, &shape_length, &b_));
463
464 llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext());
465
466 const char* acquire_func_name =
467 kind == XfeedKind::kInfeed
468 ? runtime::kAcquireInfeedBufferForDequeueSymbolName
469 : runtime::kAcquireOutfeedBufferForPopulationSymbolName;
470
471 // Implementation note: this call informs the runtime that it wants a
472 // buffer of size exactly 'length_32', and the runtime is responsible for
473 // check-failing the process if there is a mismatch, versus passing us
474 // back a buffer that we might overrun.
475 llvm::Value* acquired_pointer =
476 EmitCallToFunc(acquire_func_name,
477 {GetExecutableRunOptionsArgument(), b_.getInt32(length_32),
478 shape_ptr, b_.getInt32(shape_length)},
479 i8_ptr_type);
480 if (kind == XfeedKind::kInfeed) {
481 // Copy to the program buffer address from the acquired buffer.
482 MemCpy(program_buffer_address, /*DstAlign=*/llvm::Align(1),
483 acquired_pointer,
484 /*SrcAlign=*/llvm::Align(1), length_32);
485 } else {
486 // Outfeed -- copy from the in-program address to the acquired buffer.
487 MemCpy(acquired_pointer, /*DstAlign=*/llvm::Align(1),
488 program_buffer_address,
489 /*SrcAlign=*/llvm::Align(1), length_32);
490 if (emit_code_for_msan_) {
491 // Mark the outfed data as initialized for msan. The buffer gets read by
492 // the host code, which might be msan-instrumented.
493 // TODO(b/66051036): Run the msan instrumentation pass instead.
494 const llvm::DataLayout& dl = module_->getDataLayout();
495 llvm::Type* intptr_type = b_.getIntPtrTy(dl);
496 EmitCallToFunc(
497 "__msan_unpoison",
498 {acquired_pointer, llvm::ConstantInt::get(intptr_type, length)},
499 b_.getVoidTy());
500 }
501 }
502
503 const char* release_func_name =
504 kind == XfeedKind::kInfeed
505 ? runtime::kReleaseInfeedBufferAfterDequeueSymbolName
506 : runtime::kReleaseOutfeedBufferAfterPopulationSymbolName;
507 EmitCallToFunc(release_func_name,
508 {GetExecutableRunOptionsArgument(), b_.getInt32(length_32),
509 acquired_pointer, shape_ptr, b_.getInt32(shape_length)},
510 b_.getVoidTy());
511
512 return Status::OK();
513 }
514
HandleOutfeed(HloInstruction * outfeed)515 Status IrEmitter::HandleOutfeed(HloInstruction* outfeed) {
516 // Outfeed produces no useful result, but it does return a token[] that can be
517 // threaded through to other side effecting operations to ensure ordering. In
518 // the IR emitter we treat this token as a normal u8[] and thus need to insert
519 // an entry for it in emitted_value_.
520 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(outfeed));
521
522 HloInstruction* operand = outfeed->operands()[0];
523 const Shape& operand_shape = operand->shape();
524
525 llvm::Value* value = GetEmittedValueFor(operand);
526 if (!operand_shape.IsTuple()) {
527 return EmitXfeedTransfer(XfeedKind::kOutfeed, operand_shape, value);
528 }
529
530 TF_RET_CHECK(!ShapeUtil::IsNestedTuple(operand_shape));
531
532 for (int64_t i = 0; i < operand_shape.tuple_shapes_size(); ++i) {
533 const Shape& tuple_element_shape =
534 ShapeUtil::GetTupleElementShape(operand_shape, i);
535 llvm::Value* tuple_element = llvm_ir::EmitGetTupleElement(
536 tuple_element_shape, i, MinimumAlignmentForShape(tuple_element_shape),
537 value, &b_);
538 TF_RETURN_IF_ERROR(EmitXfeedTransfer(XfeedKind::kOutfeed,
539 tuple_element_shape, tuple_element));
540 }
541
542 return Status::OK();
543 }
544
HandleSort(HloInstruction * hlo)545 Status IrEmitter::HandleSort(HloInstruction* hlo) {
546 const HloSortInstruction* sort = Cast<HloSortInstruction>(hlo);
547 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(sort));
548 Shape keys_shape = sort->keys()->shape();
549 PrimitiveType keys_type = keys_shape.element_type();
550 if (!primitive_util::IsArrayType(keys_type)) {
551 return Unimplemented("Element type %s not supported in the Sort op on CPU.",
552 PrimitiveType_Name(keys_type));
553 }
554 std::vector<llvm::Value*> destination_addresses(sort->operand_count());
555 for (int64_t i = 0; i < sort->operand_count(); ++i) {
556 ShapeIndex shape_index =
557 sort->values_count() > 0 ? ShapeIndex({i}) : ShapeIndex({});
558 const HloInstruction* operand = sort->operand(i);
559 // We assume that the layout of all involved operands and outputs is the
560 // same.
561 TF_RET_CHECK(
562 LayoutUtil::LayoutsInShapesEqual(keys_shape, operand->shape()));
563 TF_RET_CHECK(LayoutUtil::LayoutsInShapesEqual(
564 keys_shape, ShapeUtil::GetSubshape(sort->shape(), shape_index)));
565
566 // The sort is implemented in-place, therefore we first copy the operand
567 // buffer to the output buffer if they are not the same.
568 auto destination_buffer = GetAllocationSlice(*sort, shape_index);
569 destination_addresses[i] =
570 EmitBufferPointer(destination_buffer, operand->shape());
571 auto source_address = GetAllocationSlice(*operand);
572 if (destination_buffer != source_address) {
573 int64_t primitive_type_size =
574 ShapeUtil::ByteSizeOfPrimitiveType(operand->shape().element_type());
575 auto source_buffer = GetEmittedValueFor(operand);
576 int64_t size = ByteSizeOf(operand->shape());
577 MemCpy(destination_addresses[i],
578 /*DstAlign=*/llvm::Align(primitive_type_size), source_buffer,
579 /*SrcAlign=*/llvm::Align(primitive_type_size), size);
580 }
581 }
582
583 // Normalize the shape and the dimension to sort.
584 Shape normalized_keys_shape =
585 ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(keys_shape);
586 int64_t physical_dimension_to_sort = LayoutUtil::MakeLogicalToPhysical(
587 keys_shape.layout())[sort->sort_dimension()];
588
589 int64_t sort_dimension_elements =
590 normalized_keys_shape.dimensions(physical_dimension_to_sort);
591 int64_t higher_dimensions = 1;
592 for (int64_t i = 0; i < physical_dimension_to_sort; ++i) {
593 higher_dimensions *= normalized_keys_shape.dimensions(i);
594 }
595 int64_t lower_dimensions = 1;
596 for (int64_t i = normalized_keys_shape.rank() - 1;
597 i > physical_dimension_to_sort; --i) {
598 lower_dimensions *= normalized_keys_shape.dimensions(i);
599 }
600
601 CHECK(absl::c_binary_search(thread_local_computations_, sort->to_apply()));
602 llvm::Value* values = llvm_ir::EmitAllocaAtFunctionEntryWithCount(
603 b_.getInt8PtrTy(), b_.getInt32(sort->operand_count()), "cc_values_alloca",
604 &b_);
605 llvm::Value* sizes = llvm_ir::EmitAllocaAtFunctionEntryWithCount(
606 b_.getInt32Ty(), b_.getInt32(sort->operand_count()), "cc_sizes_alloca",
607 &b_);
608 for (int64_t i = 0; i < sort->operand_count(); ++i) {
609 llvm::Value* value_as_i8ptr =
610 PointerCast(destination_addresses[i], b_.getInt8PtrTy());
611 llvm::Value* slot_in_values_alloca =
612 ConstInBoundsGEP1_32(b_.getInt8PtrTy(), values, i);
613 Store(value_as_i8ptr, slot_in_values_alloca);
614 llvm::Value* slot_in_sizes_alloca =
615 ConstInBoundsGEP1_32(b_.getInt32Ty(), sizes, i);
616 llvm::Value* size = b_.getInt32(ShapeUtil::ByteSizeOfPrimitiveType(
617 sort->operand(i)->shape().element_type()));
618 Store(size, slot_in_sizes_alloca);
619 }
620
621 auto less_than_function = FindOrDie(emitted_functions_, sort->to_apply());
622 EmitCallToFunc(
623 runtime::kKeyValueSortSymbolName,
624 {b_.getInt64(higher_dimensions), b_.getInt64(sort_dimension_elements),
625 b_.getInt64(lower_dimensions), values,
626 b_.getInt32(sort->operand_count()), sizes, b_.getInt1(sort->is_stable()),
627 GetExecutableRunOptionsArgument(), GetProfileCountersArgument(),
628 less_than_function},
629 b_.getVoidTy());
630
631 if (sort->values_count() > 0) {
632 llvm_ir::EmitTuple(GetIrArrayFor(sort), destination_addresses, &b_);
633 }
634 return Status::OK();
635 }
636
HandleTuple(HloInstruction * tuple)637 Status IrEmitter::HandleTuple(HloInstruction* tuple) {
638 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(tuple));
639 std::vector<llvm::Value*> base_ptrs;
640 for (auto operand : tuple->operands()) {
641 base_ptrs.push_back(GetEmittedValueFor(operand));
642 }
643 llvm_ir::EmitTuple(GetIrArrayFor(tuple), base_ptrs, &b_);
644 return Status::OK();
645 }
646
HandleReduceWindow(HloInstruction * reduce_window)647 Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) {
648 // Pseudo code for reduce window:
649 //
650 // for (coordinates O in the output)
651 // value = init_value;
652 // for (coordinates W in the window)
653 // for each index i:
654 // input coordinates I_i = O_i * stride_i + W_i - pad_low_i
655 // if I within bounds of input:
656 // value = function(value, input(I));
657 // output(O) = value;
658 //
659 // This is completely un-optimized and just here to have something
660 // that works.
661 return DefaultAction(reduce_window);
662 }
663
HandleSelectAndScatter(HloInstruction * select_and_scatter)664 Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) {
665 CHECK_EQ(select_and_scatter->operand_count(), 3);
666 const auto operand = select_and_scatter->operand(0);
667 const auto source = select_and_scatter->operand(1);
668 const auto init_value = select_and_scatter->operand(2);
669 const Window& window = select_and_scatter->window();
670 PrimitiveType operand_element_type = operand->shape().element_type();
671 const int64_t rank = operand->shape().rank();
672 CHECK_EQ(rank, source->shape().rank());
673 CHECK_EQ(rank, window.dimensions_size());
674
675 // TODO(b/31410564): Implement dilation for select-and-scatter.
676 if (window_util::HasDilation(window)) {
677 return Unimplemented(
678 "Dilation for SelectAndScatter is not implemented on CPU. ");
679 }
680
681 // Pseudo code for select-and-scatter:
682 //
683 // initialized_flag is initially off for every window, and is turned on after
684 // the first iteration is completed and the first operand value is selected.
685 //
686 // output(*) = init_value
687 // for (coordinates S in the source) {
688 // initialized_flag = false
689 // for (coordinates W in the window) {
690 // I = S * stride + W - pad_low
691 // if I within bounds of operand:
692 // if !initialized_flag or select(selected_value, operand(I)) == false:
693 // selected_value = operand(I)
694 // selected_index = I
695 // initialized_flag = true
696 // }
697 // output(selected_index) = scatter(output(selected_index), source(S))
698 // }
699 //
700
701 // Initialize the output array with the given init_value.
702 TF_RETURN_IF_ERROR(EmitTargetElementLoop(
703 select_and_scatter, /*desc=*/IrName(select_and_scatter, "init"),
704 [this, init_value](const llvm_ir::IrArray::Index& target_index) {
705 llvm::Value* init_value_addr = GetEmittedValueFor(init_value);
706 return Load(init_value_addr);
707 }));
708
709 // Create a loop to iterate over the source array to scatter to the output.
710 llvm_ir::ForLoopNest source_loops(IrName(select_and_scatter), &b_);
711 const llvm_ir::IrArray::Index source_index =
712 source_loops.AddLoopsForShape(source->shape(), "source");
713 SetToFirstInsertPoint(source_loops.GetInnerLoopBodyBasicBlock(), &b_);
714
715 // Allocate space to keep the currently selected value, its index, and
716 // the boolean initialized_flag, which is initially set to false.
717 llvm::Value* selected_value_address = llvm_ir::EmitAllocaAtFunctionEntry(
718 llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_),
719 "selected_value_address", &b_,
720 MinimumAlignmentForPrimitiveType(operand_element_type));
721 llvm::Value* selected_index_address =
722 llvm_ir::EmitAllocaAtFunctionEntryWithCount(
723 b_.getInt64Ty(), b_.getInt32(rank), "selected_index_address", &b_);
724 llvm::Value* initialized_flag_address = llvm_ir::EmitAllocaAtFunctionEntry(
725 b_.getInt1Ty(), "initialized_flag_address", &b_);
726 Store(b_.getInt1(false), initialized_flag_address);
727
728 // Create the inner loop to iterate over the window.
729 llvm_ir::ForLoopNest window_loops(IrName(select_and_scatter, "window"), &b_);
730 std::vector<int64> window_size;
731 for (const auto& dim : window.dimensions()) {
732 window_size.push_back(dim.size());
733 }
734 const llvm_ir::IrArray::Index window_index = window_loops.AddLoopsForShape(
735 ShapeUtil::MakeShape(operand_element_type, window_size), "window");
736 SetToFirstInsertPoint(window_loops.GetInnerLoopBodyBasicBlock(), &b_);
737
738 // Compute the operand index to visit and evaluate the condition whether the
739 // operand index is within the bounds. The unsigned comparison includes
740 // checking whether the operand index >= 0.
741 std::vector<llvm::Value*> operand_multi_index(source_index.size());
742 llvm::Value* in_bounds_condition = b_.getTrue();
743 for (int64_t i = 0; i < rank; ++i) {
744 llvm::Value* strided_index =
745 NSWMul(source_index[i], b_.getInt64(window.dimensions(i).stride()));
746 operand_multi_index[i] =
747 NSWSub(NSWAdd(strided_index, window_index[i]),
748 b_.getInt64(window.dimensions(i).padding_low()));
749 llvm::Value* index_condition =
750 ICmpULT(operand_multi_index[i],
751 b_.getInt64(ShapeUtil::GetDimension(operand->shape(), i)));
752 in_bounds_condition = And(in_bounds_condition, index_condition);
753 }
754 CHECK(in_bounds_condition != nullptr);
755
756 // Only need to do something if the operand index is within the bounds. First
757 // check if the initialized_flag is set.
758 llvm_ir::LlvmIfData if_in_bounds =
759 llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &b_);
760 SetToFirstInsertPoint(if_in_bounds.true_block, &b_);
761 llvm_ir::LlvmIfData if_initialized = llvm_ir::EmitIfThenElse(
762 Load(initialized_flag_address), "initialized", &b_);
763
764 // If the initialized_flag is false, initialize the selected value and index
765 // with the currently visiting operand.
766 SetToFirstInsertPoint(if_initialized.false_block, &b_);
767 const auto save_operand_index =
768 [&](const llvm_ir::IrArray::Index& operand_index) {
769 for (int64_t i = 0; i < rank; ++i) {
770 llvm::Value* selected_index_address_slot =
771 InBoundsGEP(selected_index_address, {b_.getInt32(i)});
772 Store(operand_index[i], selected_index_address_slot);
773 }
774 };
775 llvm_ir::IrArray operand_array(GetIrArrayFor(operand));
776 llvm_ir::IrArray::Index operand_index(
777 operand_multi_index, operand_array.GetShape(), b_.getInt64Ty());
778 llvm::Value* operand_data =
779 operand_array.EmitReadArrayElement(operand_index, &b_);
780 Store(operand_data, selected_value_address);
781 save_operand_index(operand_index);
782 Store(b_.getInt1(true), initialized_flag_address);
783
784 // If the initialized_flag is true, call the `select` function to potentially
785 // update the selected value and index with the currently visiting operand.
786 SetToFirstInsertPoint(if_initialized.true_block, &b_);
787 llvm::Value* operand_address =
788 operand_array.EmitArrayElementAddress(operand_index, &b_);
789 llvm::Value* operand_element = Load(operand_address);
790 llvm::Value* result = EmitScalarReturningThreadLocalCall(
791 *select_and_scatter->select(),
792 {Load(selected_value_address), operand_element}, "select_function");
793
794 // If the 'select' function returns false, update the selected value and the
795 // index to the currently visiting operand.
796 llvm::Value* cond = ICmpNE(
797 result,
798 llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0),
799 "boolean_predicate");
800 llvm_ir::LlvmIfData if_select_lhs =
801 llvm_ir::EmitIfThenElse(cond, "if-select-lhs", &b_);
802 SetToFirstInsertPoint(if_select_lhs.false_block, &b_);
803 Store(Load(operand_address), selected_value_address);
804 save_operand_index(operand_index);
805
806 // After iterating over the window elements, scatter the source element to
807 // the selected index of the output. The value we store at the output
808 // location is computed by calling the `scatter` function with the source
809 // value and the current output value.
810 SetToFirstInsertPoint(window_loops.GetOuterLoopExitBasicBlock(), &b_);
811 std::vector<llvm::Value*> selected_multi_index;
812 for (int64_t i = 0; i < rank; ++i) {
813 llvm::Value* selected_index_address_slot =
814 InBoundsGEP(selected_index_address, {b_.getInt32(i)});
815 selected_multi_index.push_back(Load(selected_index_address_slot));
816 }
817 llvm_ir::IrArray source_array(GetIrArrayFor(source));
818 llvm::Value* source_value =
819 source_array.EmitReadArrayElement(source_index, &b_);
820 llvm_ir::IrArray output_array(GetIrArrayFor(select_and_scatter));
821 llvm_ir::IrArray::Index selected_index(
822 selected_multi_index, output_array.GetShape(), source_index.GetType());
823 llvm::Value* output_value =
824 output_array.EmitReadArrayElement(selected_index, &b_);
825 llvm::Value* scatter_value = EmitScalarReturningThreadLocalCall(
826 *select_and_scatter->scatter(), {output_value, source_value},
827 "scatter_function");
828 output_array.EmitWriteArrayElement(selected_index, scatter_value, &b_);
829
830 SetToFirstInsertPoint(source_loops.GetOuterLoopExitBasicBlock(), &b_);
831 return Status::OK();
832 }
833
HandleDot(HloInstruction * dot)834 Status IrEmitter::HandleDot(HloInstruction* dot) {
835 auto lhs = dot->operand(0);
836 auto rhs = dot->operand(1);
837 TF_RETURN_IF_ERROR(ElementTypesSameAndSupported(
838 /*instruction=*/*dot, /*operands=*/{lhs, rhs},
839 /*supported_types=*/
840 {PRED, S8, U8, S16, U16, S32, U32, S64, U64, F16, F32, F64, C64, C128}));
841 const DotDimensionNumbers& dnums = dot->dot_dimension_numbers();
842
843 if (dnums.lhs_contracting_dimensions_size() != 1) {
844 // This is disallowed by ShapeInference today.
845 return Unimplemented(
846 "Dot with multiple contracting dimensions not implemented.");
847 }
848
849 llvm_ir::IrArray lhs_array(GetIrArrayFor(lhs));
850 llvm_ir::IrArray rhs_array(GetIrArrayFor(rhs));
851
852 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(dot));
853 llvm_ir::IrArray target_array = GetIrArrayFor(dot);
854
855 VLOG(2) << "HandleDot: ";
856 VLOG(2) << " lhs operand: "
857 << llvm_ir::DumpToString(*lhs_array.GetBasePointer());
858 VLOG(2) << " rhs operand: "
859 << llvm_ir::DumpToString(*rhs_array.GetBasePointer());
860 VLOG(2) << " target: "
861 << llvm_ir::DumpToString(*target_array.GetBasePointer());
862
863 // Dot operation is complicated so we delegate to a helper class.
864 return EmitDotOperation(*dot, target_array, lhs_array, rhs_array,
865 /*addend_array=*/nullptr,
866 GetExecutableRunOptionsArgument(), &b_, mlir_context_,
867 hlo_module_config_, target_machine_features_);
868 }
869
HandleConvolution(HloInstruction * convolution)870 Status IrEmitter::HandleConvolution(HloInstruction* convolution) {
871 auto lhs = convolution->operand(0);
872 auto rhs = convolution->operand(1);
873 TF_RETURN_IF_ERROR(ElementTypesSameAndSupported(
874 /*instruction=*/*convolution, /*operands=*/{lhs, rhs},
875 /*supported_types=*/
876 {PRED, S8, U8, S16, U16, S32, U32, S64, U64, F16, F32, F64, C64, C128}));
877
878 // TODO(tonywy): Add PotentiallyImplementedAsMKLConvolution to support
879 // different data layouts.
880 if (PotentiallyImplementedAsEigenConvolution(*convolution,
881 target_machine_features_)) {
882 const Shape& lhs_shape = lhs->shape();
883 const Shape& rhs_shape = rhs->shape();
884 const Shape& convolution_shape = convolution->shape();
885 // The input, kernel and output agree with respect to layout.
886 if (LayoutUtil::IsMonotonicWithDim0Major(lhs_shape.layout()) &&
887 LayoutUtil::IsMonotonicWithDim0Major(rhs_shape.layout()) &&
888 LayoutUtil::IsMonotonicWithDim0Major(convolution_shape.layout())) {
889 // We lower 1D convolutions into calls to the same Eigen function as 2D
890 // convolutions, except that we pretend that the 1D convolution is really
891 // a 2D convolution with the missing dimension set to 1. We also adjust
892 // the padding, dilation parameters as needed.
893 bool one_dim_convolution = lhs_shape.dimensions_size() == 3;
894 llvm::Value* lhs_address = GetEmittedValueFor(lhs);
895 llvm::Value* rhs_address = GetEmittedValueFor(rhs);
896 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(convolution));
897
898 const ConvolutionDimensionNumbers& dnums =
899 convolution->convolution_dimension_numbers();
900
901 // Input tensor.
902 const Shape& input_shape = convolution->operand(0)->shape();
903 int64_t input_batch =
904 input_shape.dimensions(dnums.input_batch_dimension());
905 int64_t input_rows =
906 input_shape.dimensions(dnums.input_spatial_dimensions(0));
907 int64_t input_cols =
908 one_dim_convolution
909 ? 1
910 : input_shape.dimensions(dnums.input_spatial_dimensions(1));
911 int64_t input_channels =
912 input_shape.dimensions(dnums.input_feature_dimension());
913
914 // Kernel tensor.
915 const Shape& kernel_shape = convolution->operand(1)->shape();
916 int64_t kernel_rows =
917 kernel_shape.dimensions(dnums.kernel_spatial_dimensions(0));
918 int64_t kernel_cols =
919 one_dim_convolution
920 ? 1
921 : kernel_shape.dimensions(dnums.kernel_spatial_dimensions(1));
922 int64_t kernel_channels =
923 kernel_shape.dimensions(dnums.kernel_input_feature_dimension());
924 int64_t kernel_filters =
925 kernel_shape.dimensions(dnums.kernel_output_feature_dimension());
926
927 // Output tensor.
928 const Shape& convolution_shape = convolution->shape();
929 int64_t output_rows =
930 convolution_shape.dimensions(dnums.output_spatial_dimensions(0));
931 int64_t output_cols = one_dim_convolution
932 ? 1
933 : convolution_shape.dimensions(
934 dnums.output_spatial_dimensions(1));
935
936 // Extract the window stride for the convolution.
937 const Window& window = convolution->window();
938 int64_t row_stride = window.dimensions(0).stride();
939 int64_t col_stride =
940 one_dim_convolution ? 1 : window.dimensions(1).stride();
941
942 int64_t padding_top = window.dimensions(0).padding_low();
943 int64_t padding_bottom = window.dimensions(0).padding_high();
944 int64_t padding_left =
945 one_dim_convolution ? 0 : window.dimensions(1).padding_low();
946 int64_t padding_right =
947 one_dim_convolution ? 0 : window.dimensions(1).padding_high();
948
949 int64_t lhs_row_dilation = window.dimensions(0).base_dilation();
950 int64_t lhs_col_dilation =
951 one_dim_convolution ? 1 : window.dimensions(1).base_dilation();
952 int64_t rhs_row_dilation = window.dimensions(0).window_dilation();
953 int64_t rhs_col_dilation =
954 one_dim_convolution ? 1 : window.dimensions(1).window_dilation();
955
956 PrimitiveType primitive_type = lhs->shape().element_type();
957 llvm::Type* ir_ptr_type = primitive_type == F16
958 ? b_.getHalfTy()->getPointerTo()
959 : b_.getFloatTy()->getPointerTo();
960 bool multi_threaded =
961 hlo_module_config_.debug_options().xla_cpu_multi_thread_eigen();
962 bool use_mkl_dnn =
963 hlo_module_config_.debug_options().xla_cpu_use_mkl_dnn();
964
965 // TODO(b/78639006) Singlethread MKL conv2d is not implemented due to the
966 // potential race condition by setting the omp_num_threads.
967 const char* fn_name =
968 primitive_type == F16
969 ? (multi_threaded
970 ? runtime::kEigenConvF16SymbolName
971 : runtime::kEigenSingleThreadedConvF16SymbolName)
972 : (multi_threaded
973 ? (use_mkl_dnn ? runtime::kMKLConvF32SymbolName
974 : runtime::kEigenConvF32SymbolName)
975 : runtime::kEigenSingleThreadedConvF32SymbolName);
976 if (!multi_threaded && use_mkl_dnn) {
977 LOG(WARNING) << "Using Eigen instead of MKL-DNN for single-threaded "
978 "conv2d function.";
979 }
980 EmitCallToFunc(fn_name,
981 {
982 GetExecutableRunOptionsArgument(),
983 BitCast(GetEmittedValueFor(convolution), ir_ptr_type),
984 BitCast(lhs_address, ir_ptr_type),
985 BitCast(rhs_address, ir_ptr_type),
986 b_.getInt64(input_batch),
987 b_.getInt64(input_rows),
988 b_.getInt64(input_cols),
989 b_.getInt64(input_channels),
990 b_.getInt64(kernel_rows),
991 b_.getInt64(kernel_cols),
992 b_.getInt64(kernel_channels),
993 b_.getInt64(kernel_filters),
994 b_.getInt64(output_rows),
995 b_.getInt64(output_cols),
996 b_.getInt64(row_stride),
997 b_.getInt64(col_stride),
998 b_.getInt64(padding_top),
999 b_.getInt64(padding_bottom),
1000 b_.getInt64(padding_left),
1001 b_.getInt64(padding_right),
1002 b_.getInt64(lhs_row_dilation),
1003 b_.getInt64(lhs_col_dilation),
1004 b_.getInt64(rhs_row_dilation),
1005 b_.getInt64(rhs_col_dilation),
1006 },
1007 b_.getVoidTy(), /*does_not_throw=*/true,
1008 /*only_accesses_arg_memory=*/true);
1009
1010 return Status::OK();
1011 }
1012 }
1013
1014 // This is a completely un-optimized version of convolution just to
1015 // have an early version that works. E.g. the input index and
1016 // padding calculation is not hoisted out of the inner loop.
1017 //
1018 // See the description of convolution in the XLA documentation for the pseudo
1019 // code for convolution.
1020 return DefaultAction(convolution);
1021 }
1022
HandleFft(HloInstruction * fft)1023 Status IrEmitter::HandleFft(HloInstruction* fft) {
1024 auto operand = fft->operand(0);
1025 TF_RETURN_IF_ERROR(ElementTypesSameAndSupported(
1026 /*instruction=*/*fft, /*operands=*/{operand},
1027 /*supported_types=*/{F32, F64, C64, C128}));
1028 TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(operand->shape().layout()));
1029 TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(fft->shape().layout()));
1030 VLOG(3) << "operand=" << ShapeUtil::HumanStringWithLayout(operand->shape());
1031 VLOG(3) << "fft=" << ShapeUtil::HumanStringWithLayout(fft->shape());
1032
1033 llvm::Value* operand_address = GetEmittedValueFor(operand);
1034 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(fft));
1035
1036 const std::vector<int64>& fft_length = fft->fft_length();
1037 int64_t input_batch = 1;
1038 for (int i = 0; i < fft->shape().dimensions_size() - fft_length.size(); i++) {
1039 input_batch *= fft->shape().dimensions(i);
1040 }
1041
1042 // Args have been computed, make the call.
1043 llvm::Type* int8_ptr_type = b_.getInt8Ty()->getPointerTo();
1044 bool multi_threaded_eigen =
1045 hlo_module_config_.debug_options().xla_cpu_multi_thread_eigen();
1046 const char* fn_name = multi_threaded_eigen
1047 ? runtime::kEigenFftSymbolName
1048 : runtime::kEigenSingleThreadedFftSymbolName;
1049 const int fft_rank = fft_length.size();
1050 EmitCallToFunc(
1051 fn_name,
1052 {GetExecutableRunOptionsArgument(),
1053 BitCast(GetEmittedValueFor(fft), int8_ptr_type),
1054 BitCast(operand_address, int8_ptr_type), b_.getInt32(fft->fft_type()),
1055 b_.getInt32(operand->shape().element_type() == F64 ||
1056 operand->shape().element_type() == C128),
1057 b_.getInt32(fft_rank), b_.getInt64(input_batch),
1058 b_.getInt64(fft_rank > 0 ? fft_length[0] : 0),
1059 b_.getInt64(fft_rank > 1 ? fft_length[1] : 0),
1060 b_.getInt64(fft_rank > 2 ? fft_length[2] : 0)},
1061 b_.getVoidTy(), /*does_not_throw=*/true,
1062 /*only_accesses_arg_memory=*/false,
1063 /*only_accesses_inaccessible_mem_or_arg_mem=*/true);
1064
1065 return Status::OK();
1066 }
1067
HandleAllReduceSingleReplica(HloInstruction * crs)1068 Status IrEmitter::HandleAllReduceSingleReplica(HloInstruction* crs) {
1069 // When there is a single replica, a cross replica sum is the identity
1070 // function, and the buffer assignment expects a copy.
1071 //
1072 // TODO(b/80100934): We would like to eliminate one-replica CRS nodes entirely
1073 // in algebraic-simplifier, but currently on some platforms
1074 // HloModuleConfig::num_replicas changes between when the module is compiled
1075 // and when it's run.
1076 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(crs));
1077
1078 // CRS with one operand and one replica is simply the identity function.
1079 if (crs->operand_count() == 1) {
1080 return EmitMemcpy(*crs->operand(0), *crs);
1081 }
1082
1083 // CRS with multiple operands and one replica produces a (one-deep) tuple.
1084 std::vector<llvm::Value*> operand_ptrs;
1085 for (int64_t i = 0; i < crs->operand_count(); ++i) {
1086 llvm::Value* in_ptr = GetEmittedValueFor(crs->operand(i));
1087 TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice out_slice,
1088 assignment_.GetUniqueSlice(crs, {i}));
1089
1090 const Shape& operand_shape = crs->operand(i)->shape();
1091 CHECK(operand_shape.IsArray())
1092 << "Operands to all-reduce must be arrays: " << crs->ToString();
1093 operand_ptrs.push_back(EmitBufferPointer(out_slice, operand_shape));
1094
1095 // TODO(b/63762267): Be more aggressive about specifying alignment.
1096 MemCpy(operand_ptrs.back(), /*DstAlign=*/llvm::Align(1), in_ptr,
1097 /*SrcAlign=*/llvm::Align(1), ShapeUtil::ByteSizeOf(operand_shape));
1098 }
1099 llvm_ir::EmitTuple(GetIrArrayFor(crs), operand_ptrs, &b_);
1100 return Status::OK();
1101 }
1102
HandleAllReduceMultipleReplica(HloInstruction * crs)1103 Status IrEmitter::HandleAllReduceMultipleReplica(HloInstruction* crs) {
1104 CHECK_GE(crs->operand_count(), 1);
1105 PrimitiveType datatype = crs->operand(0)->shape().element_type();
1106 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(crs));
1107
1108 bool is_datatype_supported = [&] {
1109 // TODO(cheshire): Fix duplication wrt. cpu_runtime
1110 switch (datatype) {
1111 case PRED:
1112 case S8:
1113 case U8:
1114 case S32:
1115 case U32:
1116 case S64:
1117 case U64:
1118 case F16:
1119 case F32:
1120 case F64:
1121 return true;
1122 default:
1123 return false;
1124 }
1125 }();
1126
1127 if (!is_datatype_supported) {
1128 return Unimplemented("AllReduce for datatype '%s' is not supported",
1129 primitive_util::LowercasePrimitiveTypeName(datatype));
1130 }
1131
1132 if (!MatchReductionComputation(crs->to_apply()).has_value()) {
1133 return Unimplemented("AllReduce for computation '%s' is not supported",
1134 crs->to_apply()->ToString());
1135 }
1136
1137 std::string replica_groups = ReplicaGroupsToString(crs->replica_groups());
1138 int32_t replica_groups_size = replica_groups.size();
1139 llvm::Value* replica_groups_v = b_.CreateGlobalStringPtr(replica_groups);
1140
1141 bool is_tuple = crs->operand_count() > 1;
1142 std::vector<llvm::Value*> input_buffer_ptrs;
1143 std::vector<llvm::Value*> output_buffer_ptrs;
1144 if (is_tuple) {
1145 CHECK(crs->shape().IsTuple());
1146
1147 for (int64_t i = 0; i < crs->operand_count(); i++) {
1148 const HloInstruction* op = crs->operand(i);
1149 TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice out_slice,
1150 assignment_.GetUniqueSlice(crs, {i}));
1151 const Shape& operand_shape = crs->operand(i)->shape();
1152 CHECK(operand_shape.IsArray())
1153 << "Operands to all-reduce must be arrays: " << crs->ToString();
1154 output_buffer_ptrs.push_back(EmitBufferPointer(out_slice, operand_shape));
1155 input_buffer_ptrs.push_back(GetEmittedValueFor(op));
1156 }
1157 } else {
1158 Shape shape = crs->operand(0)->shape();
1159 TF_ASSIGN_OR_RETURN(BufferAllocation::Slice input_slice,
1160 assignment_.GetUniqueSlice(crs->operand(0), {}));
1161 TF_ASSIGN_OR_RETURN(BufferAllocation::Slice output_slice,
1162 assignment_.GetUniqueSlice(crs, {}));
1163 input_buffer_ptrs.push_back(EmitBufferPointer(input_slice, shape));
1164 output_buffer_ptrs.push_back(EmitBufferPointer(output_slice, shape));
1165 }
1166
1167 llvm::Value* input_buffers =
1168 EncodeArrayFunctionArguments(input_buffer_ptrs, "input_buffers", &b_);
1169 llvm::Value* output_buffers =
1170 EncodeArrayFunctionArguments(output_buffer_ptrs, "output_buffers", &b_);
1171
1172 int32_t shape_length;
1173 TF_ASSIGN_OR_RETURN(llvm::Value * shape_ptr,
1174 llvm_ir::EncodeSelfDescribingShapeConstant(
1175 crs->shape(), &shape_length, &b_));
1176
1177 llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext());
1178 EmitCallToFunc(
1179 runtime::kAllReduceSymbolName,
1180 {/*run_options=*/GetExecutableRunOptionsArgument(),
1181 /*replica_groups=*/replica_groups_v,
1182 /*replica_groups_size=*/b_.getInt32(replica_groups_size),
1183
1184 /*channel_id_present=*/
1185 b_.getInt32(static_cast<int32>(crs->channel_id().has_value())),
1186 /*op_id=*/
1187 b_.getInt64(crs->channel_id().has_value()
1188 ? *crs->channel_id()
1189 : crs->GetModule()->unique_id()),
1190 /*reduction_kind=*/
1191 b_.getInt32(
1192 static_cast<int32>(*MatchReductionComputation(crs->to_apply()))),
1193 /*shape_ptr=*/shape_ptr,
1194 /*shape_length=*/b_.getInt32(shape_length),
1195 /*num_buffers=*/b_.getInt32(crs->operand_count()),
1196 /*input_buffers=*/b_.CreateBitCast(input_buffers, i8_ptr_type),
1197 /*output_buffers=*/b_.CreateBitCast(output_buffers, i8_ptr_type)},
1198 b_.getVoidTy());
1199
1200 return Status::OK();
1201 }
1202
HandleAllReduce(HloInstruction * crs)1203 Status IrEmitter::HandleAllReduce(HloInstruction* crs) {
1204 if (hlo_module_config_.replica_count() == 1) {
1205 return HandleAllReduceSingleReplica(crs);
1206 }
1207 return HandleAllReduceMultipleReplica(crs);
1208 }
1209
HandleAllToAll(HloInstruction * instruction)1210 Status IrEmitter::HandleAllToAll(HloInstruction* instruction) {
1211 auto* instr = Cast<HloAllToAllInstruction>(instruction);
1212 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(instruction));
1213 CHECK(!instr->split_dimension() && instr->shape().IsTuple())
1214 << "Only tuple AllToAll is supported";
1215
1216 llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext());
1217 std::string replica_groups =
1218 ReplicaGroupsToString(instruction->replica_groups());
1219 int32_t replica_groups_size = replica_groups.size();
1220 llvm::Value* replica_groups_v = b_.CreateGlobalStringPtr(replica_groups);
1221
1222 int64_t buffer_size = -1;
1223 std::vector<llvm::Value*> input_buffer_ptrs;
1224 std::vector<llvm::Value*> output_buffer_ptrs;
1225
1226 for (int64_t i = 0; i < instruction->operand_count(); i++) {
1227 const HloInstruction* op = instruction->operand(i);
1228 TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice out_slice,
1229 assignment_.GetUniqueSlice(instruction, {i}));
1230 const Shape& operand_shape = instruction->operand(i)->shape();
1231 CHECK(operand_shape.IsArray())
1232 << "Operands to all-to-all must be arrays: " << instruction->ToString();
1233 output_buffer_ptrs.push_back(EmitBufferPointer(out_slice, operand_shape));
1234 input_buffer_ptrs.push_back(GetEmittedValueFor(op));
1235 CHECK(buffer_size == -1 || buffer_size == out_slice.size());
1236 buffer_size = out_slice.size();
1237 }
1238
1239 llvm::Value* input_buffers =
1240 EncodeArrayFunctionArguments(input_buffer_ptrs, "input_buffers", &b_);
1241 llvm::Value* output_buffers =
1242 EncodeArrayFunctionArguments(output_buffer_ptrs, "output_buffers", &b_);
1243
1244 EmitCallToFunc(
1245 runtime::kAllToAllSymbolName,
1246 {/*run_options=*/GetExecutableRunOptionsArgument(),
1247 /*channel_id_present=*/
1248 b_.getInt32(static_cast<int32>(instruction->channel_id().has_value())),
1249 /*op_id=*/
1250 b_.getInt64(instruction->channel_id().has_value()
1251 ? *instruction->channel_id()
1252 : instruction->GetModule()->unique_id()),
1253 /*replica_groups=*/replica_groups_v,
1254 /*replica_groups_size=*/b_.getInt32(replica_groups_size),
1255 /*num_buffers=*/b_.getInt32(instruction->operand_count()),
1256 /*buffer_size=*/b_.getInt64(buffer_size),
1257 /*source_buffers=*/b_.CreateBitCast(input_buffers, i8_ptr_type),
1258 /*destination_buffers=*/b_.CreateBitCast(output_buffers, i8_ptr_type)},
1259 b_.getVoidTy());
1260
1261 llvm_ir::EmitTuple(GetIrArrayFor(instruction), output_buffer_ptrs, &b_);
1262 return Status::OK();
1263 }
1264
HandleCollectivePermute(HloInstruction * crs)1265 Status IrEmitter::HandleCollectivePermute(HloInstruction* crs) {
1266 auto* instr = Cast<HloCollectivePermuteInstruction>(crs);
1267 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(instr));
1268 std::string source_target_pairs = absl::StrJoin(
1269 instr->source_target_pairs(), ",", absl::PairFormatter("="));
1270 llvm::Value* source_target_pairs_v =
1271 b_.CreateGlobalStringPtr(source_target_pairs);
1272
1273 Shape shape = crs->operand(0)->shape();
1274
1275 TF_ASSIGN_OR_RETURN(BufferAllocation::Slice input_slice,
1276 assignment_.GetUniqueSlice(crs->operand(0), {}));
1277 llvm::Value* input_buffer = EmitBufferPointer(input_slice, shape);
1278
1279 TF_ASSIGN_OR_RETURN(BufferAllocation::Slice output_slice,
1280 assignment_.GetUniqueSlice(crs, {}));
1281 llvm::Value* output_buffer = EmitBufferPointer(output_slice, shape);
1282
1283 llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext());
1284 EmitCallToFunc(
1285 runtime::kCollectivePermuteSymbolName,
1286 {/*run_options=*/GetExecutableRunOptionsArgument(),
1287 /*channel_id_present=*/
1288 b_.getInt32(static_cast<int32>(crs->channel_id().has_value())),
1289 /*op_id=*/
1290 b_.getInt64(crs->channel_id().has_value()
1291 ? *crs->channel_id()
1292 : crs->GetModule()->unique_id()),
1293 /*byte_size=*/b_.getInt32(ShapeUtil::ByteSizeOf(shape)),
1294 /*input_buffer=*/b_.CreateBitCast(input_buffer, i8_ptr_type),
1295 /*output_buffer=*/b_.CreateBitCast(output_buffer, i8_ptr_type),
1296 /*source_target_pairs=*/source_target_pairs_v,
1297 /*source_target_pairs_size=*/b_.getInt32(source_target_pairs.size())},
1298 b_.getVoidTy());
1299
1300 return Status::OK();
1301 }
1302
1303 Status IrEmitter::HandleReplicaId(HloInstruction* hlo) {
1304 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(hlo));
1305 TF_ASSIGN_OR_RETURN(BufferAllocation::Slice output_slice,
1306 assignment_.GetUniqueSlice(hlo, {}));
1307 llvm::Value* output_buffer = EmitBufferPointer(output_slice, hlo->shape());
1308 llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext());
1309 EmitCallToFunc(
1310 runtime::kReplicaIdSymbolName,
1311 {/*run_options=*/GetExecutableRunOptionsArgument(),
1312 /*output_buffer=*/b_.CreateBitCast(output_buffer, i8_ptr_type)},
1313 b_.getVoidTy());
1314 return Status::OK();
1315 }
1316
1317 Status IrEmitter::HandleParameter(HloInstruction* parameter) {
1318 VLOG(2) << "HandleParameter: " << parameter->ToString();
1319 return EmitTargetAddressForOp(parameter);
1320 }
1321
1322 // Returns true if the relative order of the unreduced dimensions stays the same
1323 // through the reduce operation.
1324 static bool ReductionPreservesLayout(const HloInstruction& reduce) {
1325 DCHECK_EQ(reduce.opcode(), HloOpcode::kReduce);
1326
1327 // Maps dimensions that were not reduced from their dimension numbers in the
1328 // source shape to their dimension numbers in the destination shape.
1329 //
1330 // So if we reduce f32[A,B,C,D] on dimensions 1 and 2, this map contains
1331 // [0->0, 3->1].
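//
// Illustrative walk-through (not part of the original comment, layouts assumed
// for exposition): with that map in hand, the check below scans the operand
// layout minor-to-major. If the operand layout is {3,2,1,0} and the result
// layout is {1,0}, the unreduced operand dims appear in order 3, 0 and must
// map (via the map above) to result dims 1, 0 in that same order, which they
// do, so this reduction preserves layout.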
1332 absl::flat_hash_map<int64, int64> unreduced_dim_map;
1333
1334 absl::flat_hash_set<int64> reduced_dims(reduce.dimensions().begin(),
1335 reduce.dimensions().end());
1336
1337 const Shape& operand_shape = reduce.operand(0)->shape();
1338 const Shape& result_shape = reduce.shape();
1339
1340 int64_t delta = 0;
1341 for (int64_t i = 0; i < operand_shape.dimensions_size(); i++) {
1342 if (reduced_dims.contains(i)) {
1343 delta++;
1344 } else {
1345 InsertOrDie(&unreduced_dim_map, i, i - delta);
1346 }
1347 }
1348
1349 // Iterate dimensions minor to major and check that the corresponding
1350 // dimensions in the source and target shapes are equivalent.
1351 int64_t result_dim_idx = 0;
1352 for (int64_t operand_dim_idx = 0;
1353 operand_dim_idx < operand_shape.dimensions_size(); operand_dim_idx++) {
1354 int64_t operand_dim =
1355 operand_shape.layout().minor_to_major(operand_dim_idx);
1356 if (!reduced_dims.contains(operand_dim)) {
1357 if (FindOrDie(unreduced_dim_map, operand_dim) !=
1358 result_shape.layout().minor_to_major(result_dim_idx++)) {
1359 return false;
1360 }
1361 }
1362 }
1363
1364 CHECK_EQ(result_dim_idx, result_shape.dimensions_size());
1365
1366 return true;
1367 }
1368
1369 IrEmitter::ReductionGenerator IrEmitter::MatchReductionGenerator(
1370 HloComputation* function, string* failure_reason) const {
1371 CHECK_EQ(function->num_parameters(), 2);
1372
1373 auto root_instruction = function->root_instruction();
1374 CHECK(ShapeUtil::IsScalar(root_instruction->shape()));
1375
1376 if (root_instruction->operand_count() != 2) {
1377 *failure_reason = "root instruction is not a binary operation";
1378 return nullptr;
1379 }
1380
1381 const Shape& root_shape = root_instruction->shape();
1382 if (ShapeUtil::ElementIsComplex(root_shape)) {
1383     // TODO(b/65408531): Complex add could be done via bitcast to <float x [2N]>.
1384     // Complex multiply would be more challenging. We could perhaps use a
1385     // strided load to get all reals in a vector, all imaginary parts in a vector, or use
1386 // CreateShuffleVector on a bitcast to float x [2N].
1387 *failure_reason = "complex values not supported";
1388 return nullptr;
1389 }
1390 bool root_is_floating_point = ShapeUtil::ElementIsFloating(root_shape);
1391 bool root_is_integral = ShapeUtil::ElementIsIntegral(root_shape);
1392 bool root_is_signed = ShapeUtil::ElementIsSigned(root_shape);
1393
1394 auto lhs = root_instruction->operand(0);
1395 auto rhs = root_instruction->operand(1);
1396
1397 auto param_0 = function->parameter_instruction(0);
1398 auto param_1 = function->parameter_instruction(1);
1399 if (!(lhs == param_0 && rhs == param_1) &&
1400 !(rhs == param_0 && lhs == param_1)) {
1401 *failure_reason =
1402 "root instruction is not a binary operation on the incoming arguments";
1403 return nullptr;
1404 }
1405
1406 CHECK(ShapeUtil::IsScalar(lhs->shape()) && ShapeUtil::IsScalar(rhs->shape()));
1407
1408 // This is visually similar to ElementalIrEmitter, though conceptually we're
1409 // doing something different here. ElementalIrEmitter emits scalar operations
1410 // while these emit scalar or vector operations depending on the type of the
1411 // operands. See CreateShardedVectorType for the actual types in use here.
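  //
  // For example (an illustrative note, not exhaustive): the kAdd case below
  // returns a generator that emits CreateAdd or CreateFAdd, and either
  // instruction works unchanged whether lhs/rhs are scalars or the <N x T>
  // vectors produced via CreateShardedVectorType.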
1412 switch (root_instruction->opcode()) {
1413 default:
1414 *failure_reason = "did not recognize root instruction opcode";
1415 return nullptr;
1416
1417 case HloOpcode::kAdd:
1418 return [root_is_integral](llvm::IRBuilder<>* b, llvm::Value* lhs,
1419 llvm::Value* rhs) {
1420 return root_is_integral ? b->CreateAdd(lhs, rhs)
1421 : b->CreateFAdd(lhs, rhs);
1422 };
1423
1424 case HloOpcode::kMultiply:
1425 return [root_is_integral](llvm::IRBuilder<>* b, llvm::Value* lhs,
1426 llvm::Value* rhs) {
1427 return root_is_integral ? b->CreateMul(lhs, rhs)
1428 : b->CreateFMul(lhs, rhs);
1429 };
1430
1431 case HloOpcode::kAnd:
1432 return [](llvm::IRBuilder<>* b, llvm::Value* lhs, llvm::Value* rhs) {
1433 return b->CreateAnd(lhs, rhs);
1434 };
1435
1436 case HloOpcode::kOr:
1437 return [](llvm::IRBuilder<>* b, llvm::Value* lhs, llvm::Value* rhs) {
1438 return b->CreateOr(lhs, rhs);
1439 };
1440
1441 case HloOpcode::kXor:
1442 return [](llvm::IRBuilder<>* b, llvm::Value* lhs, llvm::Value* rhs) {
1443 return b->CreateXor(lhs, rhs);
1444 };
1445
1446 case HloOpcode::kMaximum:
1447 return [root_is_floating_point, root_is_signed, this](
1448 llvm::IRBuilder<>* b, llvm::Value* lhs,
1449 llvm::Value* rhs) -> llvm::Value* {
1450 if (root_is_floating_point) {
1451 return llvm_ir::EmitFloatMax(
1452 lhs, rhs, b,
1453 hlo_module_config_.debug_options().xla_cpu_enable_fast_min_max());
1454 }
1455
1456 return b->CreateSelect(
1457 b->CreateICmp(root_is_signed ? llvm::ICmpInst::ICMP_SGE
1458 : llvm::ICmpInst::ICMP_UGE,
1459 lhs, rhs),
1460 lhs, rhs);
1461 };
1462
1463 case HloOpcode::kMinimum:
1464 return [root_is_floating_point, root_is_signed, this](
1465 llvm::IRBuilder<>* b, llvm::Value* lhs,
1466 llvm::Value* rhs) -> llvm::Value* {
1467 if (root_is_floating_point) {
1468 return llvm_ir::EmitFloatMin(
1469 lhs, rhs, b,
1470 hlo_module_config_.debug_options().xla_cpu_enable_fast_min_max());
1471 }
1472
1473 return b->CreateSelect(
1474 b->CreateICmp(root_is_signed ? llvm::ICmpInst::ICMP_SLE
1475 : llvm::ICmpInst::ICMP_ULE,
1476 lhs, rhs),
1477 lhs, rhs);
1478 };
1479 }
1480 }
1481
1482 IrEmitter::ShardedVectorType IrEmitter::CreateShardedVectorType(
1483 PrimitiveType element_type, unsigned element_count) {
1484 int vector_register_size_in_elements =
1485 target_machine_features_.vector_register_byte_size(
1486 *compute_function_->function()) /
1487 ShapeUtil::ByteSizeOfPrimitiveType(element_type);
1488
1489 ShardedVectorType sharded_vector_type;
1490 llvm::Type* element_ir_type =
1491 llvm_ir::PrimitiveTypeToIrType(element_type, module_);
1492
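  // Hypothetical walk-through of the decomposition below (sizes assumed for
  // illustration): with 8-element vector registers and element_count = 11, the
  // loop splits 11 = 1 + 2 + 8 and emits the shards [T, <2 x T>, <8 x T>],
  // scanning the set bits of element_count from least to most significant.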
1493 for (int i = 0, e = 1 + tensorflow::Log2Ceiling(element_count); i < e; i++) {
1494 // For every power of two present in element_count, we generate one or more
1495 // vector or scalar types.
1496 const unsigned current_size_fragment = 1u << i;
1497 if (!(element_count & current_size_fragment)) {
1498 // Power of two not present in element_count.
1499 continue;
1500 }
1501
1502 if (current_size_fragment == 1) {
1503 // Single element, use a scalar type.
1504 sharded_vector_type.push_back(element_ir_type);
1505 continue;
1506 }
1507
1508     // Lower "current_size_fragment" elements using as few vector registers as
1509     // possible.
1510
1511 if (current_size_fragment >= vector_register_size_in_elements) {
1512 auto vector_type = llvm::VectorType::get(
1513 element_ir_type, vector_register_size_in_elements, false);
1514 sharded_vector_type.insert(
1515 sharded_vector_type.end(),
1516 current_size_fragment / vector_register_size_in_elements,
1517 vector_type);
1518
1519 // Both current_size_fragment and vector_register_size_in_elements are
1520 // powers of two.
1521 CHECK_EQ(current_size_fragment % vector_register_size_in_elements, 0);
1522 continue;
1523 }
1524
1525 // For now we assume that vector_register_size_in_elements and lower powers
1526 // of two are all legal vector sizes (or at least can be lowered easily by
1527 // LLVM).
1528 sharded_vector_type.push_back(
1529 llvm::VectorType::get(element_ir_type, current_size_fragment, false));
1530 }
1531 return sharded_vector_type;
1532 }
1533
1534 StatusOr<IrEmitter::ShardedVector>
1535 IrEmitter::EmitInnerLoopForVectorizedReduction(
1536 const ReductionGenerator& reduction_generator,
1537 const llvm_ir::IrArray::Index& output_index,
1538 const ShardedVectorType& accumulator_type, HloInstruction* init_value,
1539 HloInstruction* arg, absl::Span<const int64> dimensions,
1540 llvm::Align element_alignment) {
1541 ShardedVector accumulator;
1542 accumulator.reserve(accumulator_type.size());
1543 for (auto accumulator_shard_type : accumulator_type) {
1544 accumulator.push_back(llvm_ir::EmitAllocaAtFunctionEntry(
1545 accumulator_shard_type, "accumulator", &b_, 0));
1546 }
1547
1548 llvm::Value* init_value_ssa = Load(GetEmittedValueFor(init_value));
1549
1550 for (llvm::Value* accumulator_shard : accumulator) {
1551 llvm::Value* initial_value;
1552 auto shard_type = accumulator_shard->getType()->getPointerElementType();
1553 if (auto vector_type = llvm::dyn_cast<llvm::VectorType>(shard_type)) {
1554 initial_value =
1555 VectorSplat(vector_type->getElementCount(), init_value_ssa);
1556 } else {
1557 initial_value = init_value_ssa;
1558 }
1559
1560 AlignedStore(initial_value, accumulator_shard, element_alignment);
1561 }
1562
1563 llvm_ir::ForLoopNest reduction_loop_nest(IrName(arg, "vectorized_inner"),
1564 &b_);
1565 std::vector<llvm::Value*> input_multi_index =
1566 reduction_loop_nest.AddLoopsForShapeOnDimensions(arg->shape(), dimensions,
1567 "reduction_dim");
1568
1569 SetToFirstInsertPoint(reduction_loop_nest.GetInnerLoopBodyBasicBlock(), &b_);
1570
1571 llvm_ir::IrArray arg_array(GetIrArrayFor(arg));
1572 llvm_ir::IrArray::Index::const_iterator it = output_index.begin();
1573
1574 for (auto& i : input_multi_index) {
1575 if (i == nullptr) {
1576 i = *it++;
1577 }
1578 }
1579 CHECK(output_index.end() == it);
1580 llvm_ir::IrArray::Index input_index(input_multi_index, arg->shape(),
1581 b_.getInt64Ty());
1582
1583 llvm::Value* input_address = BitCast(
1584 arg_array.EmitArrayElementAddress(input_index, &b_), b_.getInt8PtrTy());
1585
1586 for (int i = 0; i < accumulator.size(); i++) {
1587 auto input_address_typed =
1588 BitCast(input_address, accumulator[i]->getType());
1589 auto current_accumulator_value =
1590 AlignedLoad(accumulator[i], element_alignment);
1591 auto addend = AlignedLoad(input_address_typed, element_alignment);
1592 arg_array.AnnotateLoadStoreInstructionWithMetadata(addend);
1593
1594 auto reduced_result =
1595 reduction_generator(&b_, current_accumulator_value, addend);
1596 AlignedStore(reduced_result, accumulator[i], element_alignment);
1597
1598 if (i != (accumulator.size() - 1)) {
1599 input_address = ConstInBoundsGEP1_32(reduced_result->getType(),
1600 input_address_typed, 1);
1601 }
1602 }
1603
1604 SetToFirstInsertPoint(reduction_loop_nest.GetOuterLoopExitBasicBlock(), &b_);
1605
1606 ShardedVector result_ssa;
1607 result_ssa.reserve(accumulator.size());
1608 for (auto accumulator_shard : accumulator) {
1609 result_ssa.push_back(AlignedLoad(accumulator_shard, element_alignment));
1610 }
1611 return result_ssa;
1612 }
1613
1614 void IrEmitter::EmitShardedVectorStore(
1615 llvm::Value* store_address, const std::vector<llvm::Value*>& value_to_store,
1616 llvm::Align alignment, const llvm_ir::IrArray& containing_array) {
1617 for (int i = 0; i < value_to_store.size(); i++) {
1618 auto store_address_typed =
1619 BitCast(store_address,
1620 llvm::PointerType::getUnqual(value_to_store[i]->getType()));
1621
1622 auto store_instruction =
1623 AlignedStore(value_to_store[i], store_address_typed, alignment);
1624 containing_array.AnnotateLoadStoreInstructionWithMetadata(
1625 store_instruction);
1626
1627 if (i != (value_to_store.size() - 1)) {
1628 store_address = ConstInBoundsGEP1_32(value_to_store[i]->getType(),
1629 store_address_typed, 1);
1630 }
1631 }
1632 }
1633
1634 StatusOr<bool> IrEmitter::EmitVectorizedReduce(
1635 HloInstruction* reduce, HloInstruction* arg, HloInstruction* init_value,
1636 absl::Span<const int64> dimensions, HloComputation* function,
1637 string* failure_reason) {
1638 if (!reduce->shape().IsArray()) {
1639 *failure_reason = "vectorization of variadic reduce not implemented";
1640 return false;
1641 }
1642
1643 if (!ReductionPreservesLayout(*reduce)) {
1644 return false;
1645 }
1646
1647 ReductionGenerator reduction_generator =
1648 MatchReductionGenerator(function, failure_reason);
1649 if (!reduction_generator) {
1650 return false;
1651 }
1652
1653 int vector_register_size_in_elements =
1654 target_machine_features_.vector_register_byte_size(
1655 *compute_function_->function()) /
1656 ShapeUtil::ByteSizeOfPrimitiveType(reduce->shape().element_type());
1657 if (vector_register_size_in_elements == 0) {
1658 // Either we don't know the vector register width for the target or the
1659 // vector register is smaller than the size of the primitive type.
1660 return false;
1661 }
1662
1663 int vectorization_factor_in_bytes =
1664 target_machine_features_.vectorization_factor_in_bytes();
1665
1666 // We try to process vectorization_factor elements at the same time.
1667 const int vectorization_factor =
1668 vectorization_factor_in_bytes /
1669 ShapeUtil::ByteSizeOfPrimitiveType(reduce->shape().element_type());
1670
1671 bool is_reduction_over_minor_dimension = absl::c_linear_search(
1672 dimensions, LayoutUtil::Minor(arg->shape().layout(), 0));
1673
1674 llvm::Align element_alignment(tensorflow::MathUtil::GCD<unsigned>(
1675 ShapeUtil::ByteSizeOfPrimitiveType(reduce->shape().element_type()),
1676 MinimumAlignmentForPrimitiveType(reduce->shape().element_type())));
1677
1678 if (is_reduction_over_minor_dimension) {
1679 // TODO(sanjoy): Implement vectorized reduction over the minor dimension.
1680 *failure_reason = "reduction over minor dimension not implemented";
1681 return false;
1682 }
1683
1684 CHECK(!reduce->shape().IsTuple());
1685 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(reduce));
1686
1687 // We know we're not reducing over the most minor dimension, which means we
1688 // can lower the reduction loop as:
1689 //
1690 // 1. We're reducing over dimensions R0, R1.
1691 // 2. D0 is the most minor dimension.
1692 // 3. VS is the vectorization stride (we want to reduce this many elements at
1693 // once)
1694 //
1695 // for (d1 in D1) {
1696 // for (d0 in D0 with stride VS) {
1697 // vector_acc = init
1698 // for (r1 in R1) {
1699 // for (r0 in R0) {
1700   //         vector_acc = elementwise_reduce(vector_acc, input[d1, d0, r1, r0])
1701 // }
1702 // }
1703 // output[d1, d0] = vector_acc
1704 // }
1705 // }
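  //
  // Concrete illustration (sizes are assumptions, not from the original
  // comment): for a reduction over {R1 = 2, R0 = 8} with VS = 8, each
  // iteration of the d0 loop initializes one 8-wide vector_acc, performs
  // 2 * 8 = 16 elementwise vector reduces in the (r1, r0) nest, and then
  // stores the 8 accumulated outputs at output[d1, d0..d0+7].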
1706
1707 llvm_ir::ForLoopNest loop_nest(IrName(reduce), &b_);
1708 std::vector<llvm::Value*> array_multi_index(
1709 reduce->shape().dimensions_size());
1710 for (int i = LayoutUtil::MinorToMajor(reduce->shape()).size() - 1; i > 0;
1711 --i) {
1712 int64_t dimension = LayoutUtil::Minor(reduce->shape().layout(), i);
1713 int64_t start_index = 0;
1714 int64_t end_index = reduce->shape().dimensions(dimension);
1715 std::unique_ptr<llvm_ir::ForLoop> loop = loop_nest.AddLoop(
1716 start_index, end_index, absl::StrFormat("dim.%d", dimension));
1717 array_multi_index[dimension] = loop->GetIndVarValue();
1718 }
1719
1720 int64_t innermost_dimension = LayoutUtil::Minor(reduce->shape().layout(), 0);
1721 int64_t innermost_dimension_size =
1722 reduce->shape().dimensions(innermost_dimension);
1723
1724 if (llvm::BasicBlock* innermost_body_bb =
1725 loop_nest.GetInnerLoopBodyBasicBlock()) {
1726 SetToFirstInsertPoint(innermost_body_bb, &b_);
1727 }
1728
1729 auto outermost_loop_exit_block = loop_nest.GetOuterLoopExitBasicBlock();
1730
1731 if (innermost_dimension_size >= vectorization_factor) {
1732 int64_t start_index = 0;
1733 int64_t end_index = (innermost_dimension_size / vectorization_factor) *
1734 vectorization_factor;
1735 std::unique_ptr<llvm_ir::ForLoop> loop =
1736 loop_nest.AddLoop(start_index, end_index, vectorization_factor,
1737 absl::StrFormat("dim.%d", innermost_dimension));
1738 array_multi_index[innermost_dimension] = loop->GetIndVarValue();
1739
1740 SetToFirstInsertPoint(loop->GetBodyBasicBlock(), &b_);
1741
1742 ShardedVectorType vector_type = CreateShardedVectorType(
1743 reduce->shape().element_type(), vectorization_factor);
1744 llvm_ir::IrArray::Index array_index(array_multi_index, reduce->shape(),
1745 b_.getInt64Ty());
1746 TF_ASSIGN_OR_RETURN(std::vector<llvm::Value*> accumulator,
1747 EmitInnerLoopForVectorizedReduction(
1748 reduction_generator, array_index, vector_type,
1749 init_value, arg, dimensions, element_alignment));
1750
1751 llvm_ir::IrArray target_array = GetIrArrayFor(reduce);
1752 llvm::Value* output_address =
1753 target_array.EmitArrayElementAddress(array_index, &b_);
1754 EmitShardedVectorStore(output_address, accumulator, element_alignment,
1755 target_array);
1756
1757 if (auto exit_terminator = loop->GetExitBasicBlock()->getTerminator()) {
1758 CHECK_GT(LayoutUtil::MinorToMajor(reduce->shape()).size(), 1);
1759 b_.SetInsertPoint(exit_terminator);
1760 } else {
1761 CHECK_EQ(LayoutUtil::MinorToMajor(reduce->shape()).size(), 1);
1762 b_.SetInsertPoint(loop->GetExitBasicBlock());
1763 }
1764 }
1765
1766   // Since we advance the index of the inner dimension by more than one element
1767   // at a time, we may need to peel off an "epilogue" iteration to handle the
1768   // remaining elements when the dimension size is not a multiple of the stride:
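  //
  // For example (hypothetical sizes): with an innermost dimension of 19 and a
  // vectorization factor of 8, the main loop above covers indices [0, 16) in
  // steps of 8, and this epilogue handles the remaining 3 elements starting at
  // index 16 using a 3-element sharded vector.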
1769 if (innermost_dimension_size % vectorization_factor) {
1770 // TODO(b/63775531): Consider using a scalar loop here to save on code size.
1771 array_multi_index[innermost_dimension] =
1772 b_.getInt64(innermost_dimension_size -
1773 (innermost_dimension_size % vectorization_factor));
1774
1775 ShardedVectorType vector_type = CreateShardedVectorType(
1776 reduce->shape().element_type(),
1777 innermost_dimension_size % vectorization_factor);
1778 llvm_ir::IrArray::Index array_index(array_multi_index, reduce->shape(),
1779 b_.getInt64Ty());
1780 TF_ASSIGN_OR_RETURN(std::vector<llvm::Value*> accumulator,
1781 EmitInnerLoopForVectorizedReduction(
1782 reduction_generator, array_index, vector_type,
1783 init_value, arg, dimensions, element_alignment));
1784
1785 llvm_ir::IrArray target_array = GetIrArrayFor(reduce);
1786 llvm::Value* output_address =
1787 target_array.EmitArrayElementAddress(array_index, &b_);
1788 EmitShardedVectorStore(output_address, accumulator, element_alignment,
1789 target_array);
1790 }
1791
1792 if (outermost_loop_exit_block) {
1793 b_.SetInsertPoint(outermost_loop_exit_block);
1794 }
1795
1796 return true;
1797 }
1798
1799 Status IrEmitter::HandleReduce(HloInstruction* reduce) {
1800 auto arg = reduce->mutable_operand(0);
1801 auto init_value = reduce->mutable_operand(1);
1802 absl::Span<const int64> dimensions(reduce->dimensions());
1803 HloComputation* function = reduce->to_apply();
1804 if (!options::VectorizedReduceDisabled(hlo_module_config_)) {
1805 string vectorization_failure_reason;
1806 TF_ASSIGN_OR_RETURN(
1807 bool vectorization_successful,
1808 EmitVectorizedReduce(reduce, arg, init_value, dimensions, function,
1809 &vectorization_failure_reason));
1810 if (vectorization_successful) {
1811 VLOG(1) << "Successfully vectorized reduction " << reduce->ToString()
1812 << "\n";
1813 return Status::OK();
1814 } else {
1815 VLOG(1) << "Could not vectorize reduction " << reduce->ToString() << ": "
1816 << vectorization_failure_reason;
1817 }
1818 }
1819
1820 return DefaultAction(reduce);
1821 }
1822
1823 Status IrEmitter::HandleSend(HloInstruction* send) {
1824 // TODO(b/33942983): Support Send/Recv on CPU.
1825 return Unimplemented("Send is not implemented on CPU.");
1826 }
1827
1828 Status IrEmitter::HandleSendDone(HloInstruction* send_done) {
1829 // TODO(b/33942983): Support Send/Recv on CPU.
1830 return Unimplemented("Send-done is not implemented on CPU.");
1831 }
1832
1833 Status IrEmitter::HandleScatter(HloInstruction*) {
1834 return Unimplemented("Scatter is not implemented on CPUs.");
1835 }
1836
1837 Status IrEmitter::HandleSlice(HloInstruction* slice) {
1838 VLOG(2) << "HandleSlice: " << slice->ToString();
1839 auto operand = slice->operand(0);
1840 // The code below emits a sequential loop nest. For the parallel backend, use
1841 // ParallelLoopEmitter which respects dynamic loop bounds.
1842 if (ShouldEmitParallelLoopFor(*slice)) {
1843 return DefaultAction(slice);
1844 }
1845
1846 // The code below assumes the layouts are equal.
1847 if (!LayoutUtil::Equal(operand->shape().layout(), slice->shape().layout())) {
1848 return DefaultAction(slice);
1849 }
1850
1851 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(slice));
1852
1853 if (ShapeUtil::IsZeroElementArray(slice->shape())) {
1854 return Status::OK();
1855 }
1856
1857 const Layout& layout = operand->shape().layout();
1858 const int64_t num_dims = operand->shape().dimensions_size();
1859
1860 // The slice lowering finds maximal contiguous blocks of memory that can be
1861 // copied from the source to the target. This is done by looking at the
1862   // source/target layout in minor-to-major order and doing the following:
1863 //
1864 // * Find an initial segment of dimensions along which the slice uses the
1865 // whole dimension. These are the "inner" dimensions and can be folded into
1866 // the memcpy.
1867 //
1868 // * Of the remaining dimensions decide which ones require loops.
1869 //
1870 // * Implement the memcpy within the innermost loop.
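  //
  // Hypothetical example (shapes assumed for illustration): slicing
  // f32[8,16,32] with layout {2,1,0} down to f32[2,16,32] keeps dimensions 2
  // and 1 whole, so they become inner dimensions folded into a single memcpy
  // of 16*32 elements, while dimension 0 becomes the memcpy dimension (and is
  // additionally wrapped in a loop only if its slice stride is greater than 1).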
1871
1872 absl::flat_hash_set<int64> inner_dims;
1873 for (int64_t dim : LayoutUtil::MinorToMajor(layout)) {
1874 if (operand->shape().dimensions(dim) != slice->shape().dimensions(dim)) {
1875 break;
1876 }
1877 inner_dims.insert(dim);
1878 }
1879
1880 const bool is_trivial_copy = (inner_dims.size() == num_dims);
1881 if (is_trivial_copy) {
1882 if (ShapeUtil::IsEffectiveScalar(slice->shape())) {
1883 return DefaultAction(slice);
1884 } else {
1885 return EmitMemcpy(*slice, *operand);
1886 }
1887 }
1888
1889 // The memcpy will copy elements that are logically this shape (allowed to be
1890 // scalar).
1891 const Shape logical_element_shape = ShapeUtil::FilterDimensions(
1892 [&inner_dims](int64_t dim) { return inner_dims.contains(dim); },
1893 operand->shape());
1894
1895 const int64_t primitive_elements_per_logical_element =
1896 ShapeUtil::ElementsIn(logical_element_shape);
1897
1898 // memcpy_dim is the innermost (in terms of layout) dimension for which the
1899 // slice does *not* just copy all the elements along the dimension.
1900 const int64_t memcpy_dim = LayoutUtil::Minor(layout, inner_dims.size());
1901
1902 const bool memcpy_is_contiguous = slice->slice_strides(memcpy_dim) == 1;
1903 // The number of logical elements that can be copied in a single call
1904 // to memcpy. We can only copy 1 element at a time if there is a non-trivial
1905 // stride.
1906 const int64_t memcpy_logical_elements =
1907 memcpy_is_contiguous
1908 ? slice->slice_limits(memcpy_dim) - slice->slice_starts(memcpy_dim)
1909 : 1;
1910
1911 // Determine the dimensions that get lowered as loops.
1912 std::vector<int64> outer_dims;
1913 for (int64_t i = 0; i < num_dims - inner_dims.size() - 1; ++i) {
1914 outer_dims.push_back(LayoutUtil::Major(layout, i));
1915 }
1916
1917 // Is the slice along the memcpy dimension contiguous? If not, then memcpy_dim
1918   // needs to be wrapped in a loop as well.
1919 if (!memcpy_is_contiguous) {
1920 outer_dims.push_back(memcpy_dim);
1921 }
1922
1923 llvm_ir::IrArray target_array = GetIrArrayFor(slice);
1924
1925 const int64_t num_outer_loops = outer_dims.size();
1926 llvm_ir::ForLoopNest loops(IrName(slice), &b_);
1927 std::vector<llvm::Value*> target_multi_index =
1928 loops.AddLoopsForShapeOnDimensions(slice->shape(), outer_dims, "slice");
1929
1930 // Only the indices for the outer dimensions have been initialized in
1931 // target_index. The rest of the indices should get initialized to 0, since
1932 // for the rest of the dimensions the copy writes to the full dimension.
1933 std::replace(target_multi_index.begin(), target_multi_index.end(),
1934 static_cast<llvm::Value*>(nullptr),
1935 static_cast<llvm::Value*>(b_.getInt64(0)));
1936 llvm_ir::IrArray::Index target_index(target_multi_index, slice->shape(),
1937 b_.getInt64Ty());
1938
1939 if (num_outer_loops > 0) {
1940 SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_);
1941 }
1942
1943 llvm_ir::IrArray source_array = GetIrArrayFor(operand);
1944 const llvm_ir::IrArray::Index source_index = target_index.SourceIndexOfSlice(
1945 /*operand_shape=*/operand->shape(), /*starts=*/slice->slice_starts(),
1946 /*strides=*/slice->slice_strides(), /*builder=*/&b_);
1947
1948 llvm::Value* memcpy_dest =
1949 target_array.EmitArrayElementAddress(target_index, &b_, "slice.dest");
1950 llvm::Value* memcpy_source =
1951 source_array.EmitArrayElementAddress(source_index, &b_, "slice.source");
1952
1953 const int64_t memcpy_elements =
1954 primitive_elements_per_logical_element * memcpy_logical_elements;
1955
1956 EmitTransferElements(memcpy_dest, memcpy_source, memcpy_elements,
1957 slice->shape().element_type(), target_array,
1958 source_array);
1959
1960 if (VLOG_IS_ON(2)) {
1961 const int64_t memcpy_bytes =
1962 ShapeUtil::ByteSizeOf(logical_element_shape) * memcpy_elements;
1963 VLOG(2) << " emitted copy of " << memcpy_bytes << " bytes inside "
1964 << num_outer_loops << " loops";
1965 }
1966
1967 if (num_outer_loops > 0) {
1968 SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_);
1969 }
1970
1971 return Status::OK();
1972 }
1973
1974 Status IrEmitter::HandleDynamicSlice(HloInstruction* dynamic_slice) {
1975 if (ShapeUtil::IsScalar(dynamic_slice->shape())) {
1976 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(dynamic_slice));
1977 return EmitMemcpy(*dynamic_slice->operand(0), *dynamic_slice);
1978 }
1979 return DefaultAction(dynamic_slice);
1980 }
1981
1982 Status IrEmitter::HandleDynamicUpdateSlice(
1983 HloInstruction* dynamic_update_slice) {
1984 auto update = dynamic_update_slice->operand(1);
1985 if (ShapeUtil::IsScalar(dynamic_update_slice->shape())) {
1986 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(dynamic_update_slice));
1987 return EmitMemcpy(*update, *dynamic_update_slice);
1988 } else if (llvm_ir::CanUpdateDynamicSliceInPlace(dynamic_update_slice,
1989 assignment_)) {
1990 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(dynamic_update_slice));
1991 auto operands = GetIrArraysForOperandsOf(dynamic_update_slice);
1992 return llvm_ir::EmitDynamicUpdateSliceInPlace(
1993 operands, GetIrArrayFor(dynamic_update_slice),
1994 IrName(dynamic_update_slice, "in_place"), &b_);
1995 }
1996 return DefaultAction(dynamic_update_slice);
1997 }
1998
1999 Status IrEmitter::HandleRecv(HloInstruction* recv) {
2000 // TODO(b/33942983): Support Send/Recv on CPU.
2001 return Unimplemented("Recv is not implemented on CPU.");
2002 }
2003
2004 Status IrEmitter::HandleRecvDone(HloInstruction* recv_done) {
2005 // TODO(b/33942983): Support Send/Recv on CPU.
2006 return Unimplemented("Recv-done is not implemented on CPU.");
2007 }
2008
2009 Status IrEmitter::HandlePad(HloInstruction* pad) {
2010   // The CPU backend does not properly handle negative padding, but this is OK
2011   // because negative padding should have been removed by the algebraic simplifier.
2012 for (auto& padding_dimension : pad->padding_config().dimensions()) {
2013 if (padding_dimension.edge_padding_low() < 0 ||
2014 padding_dimension.edge_padding_high() < 0) {
2015 return InternalErrorStrCat(
2016 "Encountered negative padding in IrEmitter on CPU. "
2017 "This should have been eliminated at the HLO level. ",
2018 pad->ToString());
2019 }
2020 }
2021
2022 // First, fill in the padding value to all output elements.
2023 TF_RETURN_IF_ERROR(EmitTargetElementLoop(
2024 pad, "initialize",
2025 [this, pad](const llvm_ir::IrArray::Index& target_index) {
2026 const HloInstruction* padding_value = pad->operand(1);
2027 llvm::Value* padding_value_addr = GetEmittedValueFor(padding_value);
2028 return Load(padding_value_addr);
2029 }));
2030
2031 // Create a loop to iterate over the operand elements and update the output
2032 // locations where the operand elements should be stored.
2033 llvm_ir::ForLoopNest loops(IrName(pad, "assign"), &b_);
2034 const HloInstruction* operand = pad->operand(0);
2035 const llvm_ir::IrArray::Index operand_index =
2036 loops.AddLoopsForShape(operand->shape(), "operand");
2037
2038 SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_);
2039
2040 // Load an element from the operand.
2041 llvm_ir::IrArray operand_array(GetIrArrayFor(operand));
2042 llvm::Value* operand_data =
2043 operand_array.EmitReadArrayElement(operand_index, &b_);
2044
2045 // Compute the output index the operand element should be assigned to.
2046 // output_index := edge_padding_low + operand_index * (interior_padding + 1)
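  // For instance (illustrative numbers, not from the source): with
  // edge_padding_low = 1 and interior_padding = 2 on a dimension, operand
  // index 3 maps to output index 1 + 3 * (2 + 1) = 10.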
2047 const PaddingConfig& padding_config = pad->padding_config();
2048 std::vector<llvm::Value*> output_multi_index;
2049 for (size_t i = 0; i < operand_index.size(); ++i) {
2050 llvm::Value* offset =
2051 Mul(operand_index[i],
2052 b_.getInt64(padding_config.dimensions(i).interior_padding() + 1));
2053 llvm::Value* index = Add(
2054 offset, b_.getInt64(padding_config.dimensions(i).edge_padding_low()));
2055 output_multi_index.push_back(index);
2056 }
2057
2058 // Store the operand element to the computed output location.
2059 llvm_ir::IrArray output_array(GetIrArrayFor(pad));
2060 llvm_ir::IrArray::Index output_index(
2061 output_multi_index, output_array.GetShape(), operand_index.GetType());
2062 output_array.EmitWriteArrayElement(output_index, operand_data, &b_);
2063
2064 SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_);
2065 return Status::OK();
2066 }
2067
2068 Status IrEmitter::HandleFusion(HloInstruction* fusion) {
2069 auto* root = fusion->fused_expression_root();
2070 if (llvm_ir::CanEmitFusedDynamicUpdateSliceInPlace(fusion, assignment_)) {
2071 VLOG(3) << "HandleFusion FusedDynamicUpdateSliceInPlace";
2072 CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_);
2073 FusedIrEmitter fused_emitter(&elemental_emitter);
2074 BindFusionArguments(fusion, &fused_emitter);
2075
2076 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(fusion));
2077 // Delegate to common implementation of fused in-place dynamic-update-slice.
2078 return llvm_ir::EmitFusedDynamicUpdateSliceInPlace(
2079 fusion, GetIrArrayFor(fusion), &fused_emitter, &b_);
2080 } else if (fusion->IsLoopFusion()) {
2081 VLOG(3) << "HandleFusion kLoop";
2082 CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_);
2083 FusedIrEmitter fused_emitter(&elemental_emitter);
2084 BindFusionArguments(fusion, &fused_emitter);
2085 TF_ASSIGN_OR_RETURN(auto generator, fused_emitter.GetGenerator(
2086 fusion->fused_expression_root()));
2087 return EmitTargetElementLoop(fusion, generator);
2088 } else if (fusion->IsOutputFusion()) {
2089 VLOG(3) << "HandleFusion kOutput";
2090 int64_t dot_op_index =
2091 root->operand(0)->opcode() == HloOpcode::kDot ? 0 : 1;
2092 const HloInstruction* dot = root->operand(dot_op_index);
2093 CHECK_EQ(dot->opcode(), HloOpcode::kDot)
2094 << dot->ToString() << " "
2095 << fusion->fused_instructions_computation()->ToString();
2096
2097 int64_t dot_lhs_param_number = dot->operand(0)->parameter_number();
2098 int64_t dot_rhs_param_number = dot->operand(1)->parameter_number();
2099 int64_t addend_param_number =
2100 root->operand(1 - dot_op_index)->parameter_number();
2101
2102 Shape target_shape = fusion->shape();
2103 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(fusion));
2104 llvm_ir::IrArray target_array = GetIrArrayFor(fusion);
2105
2106 llvm_ir::IrArray lhs_array(
2107 GetIrArrayFor(fusion->operand(dot_lhs_param_number)));
2108 llvm_ir::IrArray rhs_array(
2109 GetIrArrayFor(fusion->operand(dot_rhs_param_number)));
2110 llvm_ir::IrArray addend_array(
2111 GetIrArrayFor(fusion->operand(addend_param_number)));
2112
2113 TF_RETURN_IF_ERROR(EmitDotOperation(
2114 *dot, target_array, lhs_array, rhs_array, &addend_array,
2115 GetExecutableRunOptionsArgument(), &b_, mlir_context_,
2116 hlo_module_config_, target_machine_features_));
2117 return Status::OK();
2118 } else {
2119 return Unimplemented("Fusion kind not implemented on CPU");
2120 }
2121 }
2122
2123 Status IrEmitter::HandleCall(HloInstruction* call) {
2124 HloComputation* computation = call->to_apply();
2125 llvm::Function* call_ir_function = FindOrDie(emitted_functions_, computation);
2126
2127 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(call));
2128
2129 if (!computation->root_instruction()->outer_dimension_partitions().empty()) {
2130 // ParallelTaskAssignment assigned partitions, emit call to
2131 // ParallelForkJoin.
2132 std::vector<llvm::Value*> call_args = GetArrayFunctionCallArguments(
2133 {}, &b_, computation->name(),
2134 /*return_value_buffer=*/emitted_value_[call],
2135 /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(),
2136 /*buffer_table_arg=*/GetBufferTableArgument(),
2137 /*profile_counters_arg=*/GetProfileCountersArgument());
2138
2139 HloInstruction* root = computation->root_instruction();
2140 TF_RETURN_IF_ERROR(EmitCallToParallelForkJoin(
2141 call_args, root->shape(), root->outer_dimension_partitions(), &b_,
2142 call_ir_function, computation->name()));
2143 } else {
2144 EmitGlobalCall(*computation, computation->name());
2145 }
2146
2147 return Status::OK();
2148 }
2149
2150 Status IrEmitter::HandleSliceToDynamic(HloInstruction* hlo) {
2151 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(hlo));
2152 std::vector<llvm::Value*> dynamic_dims;
2153 int32_t raw_data_size =
2154 ShapeUtil::ByteSizeOf(ShapeUtil::MakeStaticShape(hlo->shape()));
2155 llvm::Value* dest_buffer = GetEmittedValueFor(hlo);
2156 llvm::Value* raw_buffer =
2157 b_.CreateBitCast(dest_buffer, b_.getInt8Ty()->getPointerTo());
2158 for (int64_t i = 1; i < hlo->operand_count(); ++i) {
2159 const int64_t dim_index = i - 1;
2160 llvm::Value* source_buffer = GetEmittedValueFor(hlo->operand(i));
2161 llvm::LoadInst* dyn_dim_size = b_.CreateLoad(source_buffer, "dyn_dim_size");
2162
2163 llvm::Value* metadata = b_.CreateConstInBoundsGEP1_32(
2164 b_.getInt8Ty(), raw_buffer, raw_data_size + dim_index * sizeof(int32));
2165 b_.CreateStore(dyn_dim_size,
2166 b_.CreateBitCast(metadata, b_.getInt32Ty()->getPointerTo()));
2167 dynamic_dims.push_back(b_.CreateIntCast(dyn_dim_size, b_.getInt64Ty(),
2168 /*isSigned=*/true,
2169 "i64_dyn_dim_size"));
2170 }
2171
2172 llvm_ir::IrArray data_array = GetIrArrayFor(hlo);
2173 // Pseudo code for sliceToDynamic:
2174 //
2175 // for (index i in dynamic_dim)
2176 // dest_index = delinearize(linearize(i, dynamic_dim), static_dim)
2177 // dest[dest_index] = source[i]
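  //
  // Rough illustration (shapes assumed, row-major layout): if the static shape
  // is [2,4] and the runtime shape is [2,3], then for i = (1,2) the linear
  // index under the dynamic dims is 1*3 + 2 = 5, which delinearizes under the
  // static dims to dest_index = (1,1); the element therefore lands in the
  // densely packed prefix of the statically shaped output buffer.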
2178 auto loop_body_emitter =
2179 [&](const llvm_ir::IrArray::Index& array_index) -> Status {
2180 llvm::Value* source_element =
2181 GetIrArrayFor(hlo->operand(0)).EmitReadArrayElement(array_index, &b_);
2182 llvm::Value* linear_index = array_index.Linearize(dynamic_dims, &b_);
2183 // Delinearize the index based on the static shape.
2184 llvm_ir::IrArray::Index dest_index(linear_index, data_array.GetShape(),
2185 &b_);
2186 data_array.EmitWriteArrayElement(dest_index, source_element, &b_);
2187 return Status::OK();
2188 };
2189 return llvm_ir::LoopEmitter(loop_body_emitter, data_array.GetShape(),
2190 dynamic_dims, &b_)
2191 .EmitLoop(IrName(hlo));
2192 }
2193
2194 Status IrEmitter::HandlePadToStatic(HloInstruction* hlo) {
2195 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(hlo));
2196
2197 TF_ASSIGN_OR_RETURN(BufferAllocation::Slice data_slice,
2198 assignment_.GetUniqueSlice(hlo, {0}));
2199 std::vector<llvm::Value*> dynamic_dims;
2200 std::vector<llvm::Value*> tuple_operand_ptrs;
2201 const Shape& data_shape = ShapeUtil::GetSubshape(hlo->shape(), {0});
2202 const Shape& input_shape = hlo->operand(0)->shape();
2203 llvm::Value* data_address = EmitBufferPointer(data_slice, data_shape);
2204 llvm_ir::IrArray data_array(data_address, data_shape);
2205 llvm::Value* source_buffer = GetEmittedValueFor(hlo->operand(0));
2206 llvm::Value* raw_buffer =
2207 b_.CreateBitCast(source_buffer, b_.getInt8Ty()->getPointerTo());
2208 int64_t raw_data_size =
2209 ShapeUtil::ByteSizeOf(ShapeUtil::MakeStaticShape(input_shape));
2210
2211 // Put a placeholder for the data array's pointer
2212 tuple_operand_ptrs.push_back(data_array.GetBasePointer());
2213   // PadToStatic takes a dynamic tensor as input and produces a variadic number
2214   // of outputs: (static_tensor, dynamic_dim_0, dynamic_dim_1, ... ).
2215   // Dynamic dimension sizes start at output index 1.
2216 for (int64_t i = 1; i < hlo->shape().tuple_shapes_size(); ++i) {
2217 // Read from the metadata section of the dynamic input (operand 0).
2218 const Shape& dim_shape = ShapeUtil::GetSubshape(hlo->shape(), {i});
2219 TF_RET_CHECK(Shape::Equal()(dim_shape, ShapeUtil::MakeScalarShape(S32)));
2220 TF_ASSIGN_OR_RETURN(BufferAllocation::Slice dim_size_slice,
2221 assignment_.GetUniqueSlice(hlo, {i}));
2222 llvm::Value* dest_dim_size_address =
2223 EmitBufferPointer(dim_size_slice, data_shape);
2224 const int64_t dim_index = i - 1;
2225 llvm::Value* metadata = b_.CreateConstInBoundsGEP1_32(
2226 b_.getInt8Ty(), raw_buffer, raw_data_size + dim_index * sizeof(int32));
2227 llvm::Value* dyn_dim_size = b_.CreateLoad(
2228 b_.CreateBitCast(metadata, b_.getInt32Ty()->getPointerTo()),
2229 "dyn_dim_size");
2230 b_.CreateStore(dyn_dim_size,
2231 b_.CreateBitCast(dest_dim_size_address,
2232 b_.getInt32Ty()->getPointerTo()));
2233 dynamic_dims.push_back(b_.CreateIntCast(dyn_dim_size, b_.getInt64Ty(),
2234 /*isSigned=*/true,
2235 "i64_dyn_dim_size"));
2236 tuple_operand_ptrs.push_back(dest_dim_size_address);
2237 }
2238
2239 // Pseudo code for padToStatic:
2240 //
2241 // for (index i in dynamic_dim)
2242   //   source_index = delinearize(linearize(i, dynamic_dim), static_dim)
2243 // dest[i] = source[source_index]
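  //
  // This mirrors the SliceToDynamic copy above in reverse: reusing the
  // illustrative [2,4]-static / [2,3]-dynamic example (assumed shapes,
  // row-major), dest[(1,2)] is read from source[(1,1)].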
2244 auto loop_body_emitter =
2245 [&](const llvm_ir::IrArray::Index& array_index) -> Status {
2246 llvm::Value* linear_index = array_index.Linearize(dynamic_dims, &b_);
2247 llvm_ir::IrArray::Index source_index(linear_index, input_shape, &b_);
2248 llvm::Value* source_element =
2249 GetIrArrayFor(hlo->operand(0)).EmitReadArrayElement(source_index, &b_);
2250 data_array.EmitWriteArrayElement(array_index, source_element, &b_);
2251 return Status::OK();
2252 };
2253 TF_RETURN_IF_ERROR(
2254 llvm_ir::LoopEmitter(loop_body_emitter, input_shape, dynamic_dims, &b_)
2255 .EmitLoop(IrName(hlo)));
2256
2257 // Emit static tensor and dynamic sizes as one tuple.
2258 llvm_ir::EmitTuple(GetIrArrayFor(hlo), tuple_operand_ptrs, &b_);
2259 return Status::OK();
2260 }
2261
2262 Status IrEmitter::HandleTopK(HloInstruction* hlo) {
2263 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(hlo));
2264 const HloInstruction* input = hlo->operand(0);
2265 const int64_t k = hlo->shape().tuple_shapes(0).dimensions().back();
2266 const bool has_batch = hlo->shape().tuple_shapes(0).dimensions_size() == 2;
2267 TF_RET_CHECK(input->shape().element_type() == F32);
2268 TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(
2269 hlo->shape().tuple_shapes(0).layout()));
2270 TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(
2271 hlo->shape().tuple_shapes(1).layout()));
2272 TF_RET_CHECK(
2273 LayoutUtil::IsMonotonicWithDim0Major(hlo->operand(0)->shape().layout()));
2274
2275 TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice values_slice,
2276 assignment_.GetUniqueSlice(hlo->operand(0), {}));
2277 TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice out_values_slice,
2278 assignment_.GetUniqueSlice(hlo, {0}));
2279 TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice out_indices_slice,
2280 assignment_.GetUniqueSlice(hlo, {1}));
2281 llvm::Value* values_ptr =
2282 EmitBufferPointer(values_slice, hlo->operand(0)->shape());
2283 llvm::Value* out_values_ptr =
2284 EmitBufferPointer(out_values_slice, hlo->shape().tuple_shapes(0));
2285 llvm::Value* out_indices_ptr =
2286 EmitBufferPointer(out_indices_slice, hlo->shape().tuple_shapes(1));
2287 EmitCallToFunc(
2288 runtime::kTopKF32SymbolName,
2289 {b_.getInt64(has_batch ? input->shape().dimensions(0) : 1),
2290 b_.getInt64(input->shape().dimensions().back()), b_.getInt64(k),
2291 BitCast(values_ptr, b_.getFloatTy()->getPointerTo()),
2292 BitCast(out_values_ptr, b_.getFloatTy()->getPointerTo()),
2293 BitCast(out_indices_ptr, b_.getInt32Ty()->getPointerTo())},
2294 b_.getVoidTy());
2295
2296 llvm_ir::EmitTuple(GetIrArrayFor(hlo), {out_values_ptr, out_indices_ptr},
2297 &b_);
2298 return Status::OK();
2299 }
2300
2301 Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) {
2302 if (custom_call->custom_call_target() == "PadToStatic") {
2303 return HandlePadToStatic(custom_call);
2304 }
2305 if (custom_call->custom_call_target() == "SliceToDynamic") {
2306 return HandleSliceToDynamic(custom_call);
2307 }
2308 if (custom_call->custom_call_target() == "TopK") {
2309 return HandleTopK(custom_call);
2310 }
2311
2312 auto typed_custom_call = Cast<HloCustomCallInstruction>(custom_call);
2313 switch (typed_custom_call->api_version()) {
2314 case CustomCallApiVersion::API_VERSION_ORIGINAL:
2315 break;
2316 case CustomCallApiVersion::API_VERSION_STATUS_RETURNING:
2317 // TODO(b/194529780): Support status-returning custom calls on CPU.
2318 return Unimplemented(
2319 "XLA CPU does not support custom calls that return a success/failure "
2320 "status");
2321 default:
2322 return InternalError(
2323 "Unknown custom-call API version enum value: %d (%s)",
2324 typed_custom_call->api_version(),
2325 CustomCallApiVersion_Name(typed_custom_call->api_version()));
2326 }
2327
2328 absl::Span<HloInstruction* const> operands(custom_call->operands());
2329 llvm::Type* i8_ptr_type = b_.getInt8PtrTy();
2330 llvm::AllocaInst* operands_alloca =
2331 llvm_ir::EmitAllocaAtFunctionEntryWithCount(
2332 i8_ptr_type, b_.getInt32(operands.size()), "cc_operands_alloca", &b_);
2333 for (size_t i = 0; i < operands.size(); ++i) {
2334 const HloInstruction* operand = operands[i];
2335 llvm::Value* operand_as_i8ptr =
2336 PointerCast(GetEmittedValueFor(operand), i8_ptr_type);
2337 llvm::Value* slot_in_operands_alloca =
2338 InBoundsGEP(operands_alloca, {b_.getInt64(i)});
2339 Store(operand_as_i8ptr, slot_in_operands_alloca);
2340 }
2341 if (emit_code_for_msan_) {
2342 // Mark the alloca as initialized for msan. The buffer gets read by the
2343 // custom callee, which might be msan-instrumented.
2344 // TODO(b/66051036): Run the msan instrumentation pass instead.
2345 const llvm::DataLayout& dl = module_->getDataLayout();
2346 llvm::Type* intptr_type = b_.getIntPtrTy(dl);
2347 EmitCallToFunc(
2348 "__msan_unpoison",
2349 {PointerCast(operands_alloca, i8_ptr_type),
2350 llvm::ConstantInt::get(
2351 intptr_type, *operands_alloca->getAllocationSizeInBits(dl) / 8)},
2352 b_.getVoidTy());
2353 }
2354
2355 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(custom_call));
2356 // Write the tuple table if the output is a tuple.
2357 if (custom_call->shape().IsTuple()) {
2358 std::vector<llvm::Value*> base_ptrs;
2359 for (int i = 0; i < ShapeUtil::TupleElementCount(custom_call->shape());
2360 ++i) {
2361 const Shape& elem_shape =
2362 ShapeUtil::GetTupleElementShape(custom_call->shape(), i);
2363 TF_RET_CHECK(!elem_shape.IsTuple()) << "Nested tuples not implemented";
2364 TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice,
2365 assignment_.GetUniqueSlice(custom_call, {i}));
2366 llvm::Value* addr = EmitBufferPointer(slice, elem_shape);
2367 base_ptrs.push_back(addr);
2368 }
2369 llvm_ir::EmitTuple(GetIrArrayFor(custom_call), base_ptrs, &b_);
2370 }
2371 auto* output_address_arg =
2372 PointerCast(GetEmittedValueFor(custom_call), i8_ptr_type);
2373
2374 EmitCallToFunc(custom_call->custom_call_target(),
2375 {output_address_arg, operands_alloca}, b_.getVoidTy());
2376
2377 return Status::OK();
2378 }
2379
2380 Status IrEmitter::HandleWhile(HloInstruction* xla_while) {
2381 // Precondition: Condition computation must return a scalar bool.
2382 HloComputation* condition = xla_while->while_condition();
2383 TF_RET_CHECK(ShapeUtil::IsScalar(condition->root_instruction()->shape()) &&
2384 condition->root_instruction()->shape().element_type() == PRED)
2385 << "While condition computation must return bool; got: "
2386 << ShapeUtil::HumanString(condition->root_instruction()->shape());
2387 // Check that all while-related buffers share an allocation slice.
2388 TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
2389 xla_while->shape(),
2390 [this, &xla_while](const Shape& /*subshape*/,
2391 const ShapeIndex& index) -> Status {
2392 auto check = [this](const HloInstruction* a, const HloInstruction* b,
2393 const ShapeIndex& index) {
2394 const BufferAllocation::Slice slice_a =
2395 assignment_.GetUniqueSlice(a, index).ConsumeValueOrDie();
2396 const BufferAllocation::Slice slice_b =
2397 assignment_.GetUniqueSlice(b, index).ConsumeValueOrDie();
2398 if (slice_a != slice_b) {
2399 return InternalError(
2400 "instruction %s %s does not share slice with "
2401 "instruction %s %s",
2402 a->ToString(), slice_a.ToString(), b->ToString(),
2403 slice_b.ToString());
2404 }
2405 return Status::OK();
2406 };
2407 TF_RETURN_IF_ERROR(check(xla_while, xla_while->operand(0), index));
2408 TF_RETURN_IF_ERROR(check(
2409 xla_while, xla_while->while_condition()->parameter_instruction(0),
2410 index));
2411 TF_RETURN_IF_ERROR(
2412 check(xla_while, xla_while->while_body()->parameter_instruction(0),
2413 index));
2414 TF_RETURN_IF_ERROR(check(
2415 xla_while, xla_while->while_body()->root_instruction(), index));
2416 return Status::OK();
2417 }));
2418
2419 // Set emitted value to that of 'init' with which it shares an allocation.
2420 const HloInstruction* init = xla_while->operand(0);
2421 emitted_value_[xla_while] = GetEmittedValueFor(init);
2422
2423 // Generating:
2424 // while (Condition(while_result)) {
2425 // // CopyInsertion pass inserts copies which enable 'while_result' to
2426 // // be passed back in as 'Body' parameter.
2427   //     while_result = Body(while_result);
2428 // }
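  //
  // Sketch of the emitted control flow (descriptive only, block names follow
  // the IrName suffixes used below):
  //
  //   current block:  br header
  //   header:         call condition; conditional br to body or exit
  //   body:           call body; br header
  //   exit:           (code emission continues here)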
2429
2430 // Terminates the current block with a branch to a while header.
2431 llvm::BasicBlock* header_bb = llvm::BasicBlock::Create(
2432 module_->getContext(), IrName(xla_while, "header"),
2433 compute_function_->function());
2434 Br(header_bb);
2435 b_.SetInsertPoint(header_bb);
2436
2437 // Calls the condition function to determine whether to proceed with the
2438 // body. It must return a bool, so use the scalar call form.
2439 EmitGlobalCall(*xla_while->while_condition(), IrName(xla_while, "cond"));
2440 llvm::Value* while_predicate = ICmpNE(
2441 Load(GetBufferForGlobalCallReturnValue(*xla_while->while_condition())),
2442 llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0));
2443
2444 // Branches to the body or to the while exit depending on the condition.
2445 llvm::BasicBlock* body_bb =
2446 llvm::BasicBlock::Create(module_->getContext(), IrName(xla_while, "body"),
2447 compute_function_->function());
2448 llvm::BasicBlock* exit_bb = llvm::BasicBlock::Create(
2449 module_->getContext(), IrName(xla_while, "exit"));
2450 CondBr(while_predicate, body_bb, exit_bb);
2451
2452 // Calls the body function from the body block.
2453 b_.SetInsertPoint(body_bb);
2454
2455 // Calls the body function.
2456 EmitGlobalCall(*xla_while->while_body(), IrName(xla_while, "body"));
2457
2458 // Finishes with a branch back to the header.
2459 Br(header_bb);
2460
2461 // Adds the exit block to the function and sets the insert point there.
2462 compute_function_->function()->getBasicBlockList().push_back(exit_bb);
2463 b_.SetInsertPoint(exit_bb);
2464
2465 return Status::OK();
2466 }
2467
2468 StatusOr<bool> IrEmitter::EmitFastConcatenate(
2469 HloInstruction* concatenate, absl::Span<HloInstruction* const> operands,
2470 string* failure_reason) {
2471 if (ShouldEmitParallelLoopFor(*concatenate)) {
2472 *failure_reason =
2473 "cannot generate memcpy-based concat for the parallel CPU backend";
2474 return false;
2475 }
2476
2477 const Shape& output_shape = concatenate->shape();
2478 for (auto* op : operands) {
2479 if (!LayoutUtil::Equal(op->shape().layout(), output_shape.layout())) {
2480 *failure_reason = "operand has mismatching layouts";
2481 return false;
2482 }
2483 }
2484
2485 // We split the dimensions into three categories: the dimension over which we
2486 // are concatenating (concat_dim), the dimensions that are minor to it
2487 // (inner_dims) and the dimensions that are major to it (outer_dims).
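  //
  // Illustrative example (shapes assumed): concatenating along dim 1 of
  // operands that produce an f32[A,B,C] output with layout {2,1,0} gives
  // inner_dims = {2} and outer_dims = {0}, so each operand contributes one
  // memcpy of C * (its own dim-1 extent) elements per iteration of the dim-0
  // loop.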
2488
2489 int64_t concat_dim = concatenate->dimensions(0);
2490 const Layout& output_layout = output_shape.layout();
2491 auto output_min2maj = LayoutUtil::MinorToMajor(output_layout);
2492 auto concat_dim_layout_itr = absl::c_find(output_min2maj, concat_dim);
2493
2494 std::vector<int64> inner_dims(output_min2maj.begin(), concat_dim_layout_itr);
2495 std::vector<int64> outer_dims(std::next(concat_dim_layout_itr),
2496 output_min2maj.end());
2497
2498 llvm::Type* i8_ptr_type = b_.getInt8PtrTy();
2499
2500 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(concatenate));
2501 llvm_ir::IrArray target_array = GetIrArrayFor(concatenate);
2502
2503 llvm_ir::ForLoopNest loops(IrName(concatenate), &b_);
2504 std::vector<llvm::Value*> target_multi_index =
2505 loops.AddLoopsForShapeOnDimensions(output_shape, outer_dims, "concat");
2506 std::replace(target_multi_index.begin(), target_multi_index.end(),
2507 static_cast<llvm::Value*>(nullptr),
2508 static_cast<llvm::Value*>(b_.getInt64(0)));
2509 llvm_ir::IrArray::Index target_index(target_multi_index, output_shape,
2510 b_.getInt64Ty());
2511
2512 if (!outer_dims.empty()) {
2513 SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_);
2514 }
2515
2516 PrimitiveType primitive_type = output_shape.element_type();
2517 unsigned primitive_type_size =
2518 ShapeUtil::ByteSizeOfPrimitiveType(primitive_type);
2519
2520 // Contiguous subregions from each operand to the concatenate contribute to a
2521 // contiguous subregion in the target buffer starting at target_region_begin.
2522 llvm::Value* target_region_begin = BitCast(
2523 target_array.EmitArrayElementAddress(target_index, &b_, "target_region"),
2524 i8_ptr_type);
2525 int64_t byte_offset_into_target_region = 0;
2526
2527 int64_t inner_dims_product =
2528 std::accumulate(inner_dims.begin(), inner_dims.end(), 1l,
2529 [&](int64_t product, int64_t inner_dim) {
2530 return product * output_shape.dimensions(inner_dim);
2531 });
2532
2533 // For each operand, emit a memcpy from the operand to the target of size
2534 // equal to the product of inner dimensions.
2535 for (HloInstruction* operand : operands) {
2536 const Shape& input_shape = operand->shape();
2537 llvm_ir::IrArray source_array = GetIrArrayFor(operand);
2538 llvm_ir::IrArray::Index source_index(target_multi_index, operand->shape(),
2539 b_.getInt64Ty());
2540 llvm::Value* copy_source_address = BitCast(
2541 source_array.EmitArrayElementAddress(source_index, &b_, "src_addr"),
2542 i8_ptr_type);
2543
2544 llvm::Value* copy_target_address =
2545 GEP(target_region_begin, b_.getInt64(byte_offset_into_target_region));
2546
2547 EmitTransferElements(
2548 copy_target_address, copy_source_address,
2549 inner_dims_product * input_shape.dimensions(concat_dim), primitive_type,
2550 target_array, source_array);
2551
2552 byte_offset_into_target_region += inner_dims_product *
2553 input_shape.dimensions(concat_dim) *
2554 primitive_type_size;
2555 }
2556
2557 if (!outer_dims.empty()) {
2558 SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_);
2559 }
2560
2561 return true;
2562 }
2563
2564 llvm::Value* IrEmitter::EmitPrintf(absl::string_view fmt,
2565 absl::Span<llvm::Value* const> arguments) {
2566 llvm::Type* ptr_ty = b_.getInt8Ty()->getPointerTo();
2567 std::vector<llvm::Value*> call_args;
2568 call_args.push_back(b_.CreateGlobalStringPtr(llvm_ir::AsStringRef(fmt)));
2569 absl::c_copy(arguments, std::back_inserter(call_args));
2570 return b_.CreateCall(
2571 b_.GetInsertBlock()->getParent()->getParent()->getOrInsertFunction(
2572 "printf", llvm::FunctionType::get(b_.getInt32Ty(), {ptr_ty},
2573 /*isVarArg=*/true)),
2574 call_args);
2575 }
2576
2577 llvm::Value* IrEmitter::EmitPrintfToStderr(
2578 absl::string_view fmt, absl::Span<llvm::Value* const> arguments) {
2579 llvm::Type* ptr_ty = b_.getInt8Ty()->getPointerTo();
2580 std::vector<llvm::Value*> call_args;
2581 call_args.push_back(b_.CreateGlobalStringPtr(llvm_ir::AsStringRef(fmt)));
2582 absl::c_copy(arguments, std::back_inserter(call_args));
2583 return b_.CreateCall(
2584 b_.GetInsertBlock()->getParent()->getParent()->getOrInsertFunction(
2585 runtime::kPrintfToStderrSymbolName,
2586 llvm::FunctionType::get(b_.getInt32Ty(), {ptr_ty},
2587 /*isVarArg=*/true)),
2588 call_args);
2589 }
2590
2591 llvm::Value* IrEmitter::EmitCallToFunc(
2592 std::string func_name, const std::vector<llvm::Value*>& arguments,
2593 llvm::Type* return_type, bool does_not_throw, bool only_accesses_arg_memory,
2594 bool only_accesses_inaccessible_mem_or_arg_mem) {
2595 std::vector<llvm::Type*> types;
2596 types.reserve(arguments.size());
2597 absl::c_transform(arguments, std::back_inserter(types),
2598 [&](llvm::Value* val) { return val->getType(); });
2599 llvm::FunctionType* func_type =
2600 llvm::FunctionType::get(return_type, types, /*isVarArg=*/false);
2601 auto func = llvm::dyn_cast<llvm::Function>(
2602 module_->getOrInsertFunction(func_name, func_type).getCallee());
2603 func->setCallingConv(llvm::CallingConv::C);
2604 if (does_not_throw) {
2605 func->setDoesNotThrow();
2606 }
2607 if (only_accesses_arg_memory) {
2608 func->setOnlyAccessesArgMemory();
2609 }
2610 if (only_accesses_inaccessible_mem_or_arg_mem) {
2611 func->setOnlyAccessesInaccessibleMemOrArgMem();
2612 }
2613 return b_.CreateCall(func, arguments);
2614 }
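// Sketch of a typical call site (the runtime symbol and argument values here
// are hypothetical): emitting a void-returning call to a C-ABI helper that
// only touches memory reachable from its pointer arguments:
//   EmitCallToFunc("__xla_cpu_runtime_SomeHelper", {dst_ptr, src_ptr},
//                  b_.getVoidTy(), /*does_not_throw=*/true,
//                  /*only_accesses_arg_memory=*/true,
//                  /*only_accesses_inaccessible_mem_or_arg_mem=*/false);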
2615
2616 void IrEmitter::EmitTransferElements(llvm::Value* target, llvm::Value* source,
2617 int64_t element_count,
2618 PrimitiveType primitive_type,
2619 const llvm_ir::IrArray& target_array,
2620 const llvm_ir::IrArray& source_array) {
2621 unsigned primitive_type_size =
2622 ShapeUtil::ByteSizeOfPrimitiveType(primitive_type);
2623 llvm::Align element_alignment(tensorflow::MathUtil::GCD<unsigned>(
2624 primitive_type_size, MinimumAlignmentForPrimitiveType(primitive_type)));
2625 llvm::Type* primitive_ptr_type = llvm::PointerType::getUnqual(
2626 llvm_ir::PrimitiveTypeToIrType(primitive_type, module_));
2627
2628 if (element_count == 1) {
2629 auto* load_instruction =
2630 AlignedLoad(BitCast(source, primitive_ptr_type), element_alignment);
2631 source_array.AnnotateLoadStoreInstructionWithMetadata(load_instruction);
2632 auto* store_instruction =
2633 AlignedStore(load_instruction, BitCast(target, primitive_ptr_type),
2634 element_alignment);
2635 target_array.AnnotateLoadStoreInstructionWithMetadata(store_instruction);
2636 } else {
2637 auto* memcpy_instruction = b_.CreateMemCpy(
2638 target, /*DstAlign=*/llvm::Align(element_alignment), source,
2639 /*SrcAlign=*/llvm::Align(element_alignment),
2640 element_count * primitive_type_size);
2641
2642 // The memcpy does the load and the store internally. The aliasing-related
2643 // metadata has to reflect that.
2644 std::map<int, llvm::MDNode*> merged_metadata =
2645 llvm_ir::MergeMetadata(&module_->getContext(), source_array.metadata(),
2646 target_array.metadata());
2647 for (const auto& kind_md_pair : merged_metadata) {
2648 memcpy_instruction->setMetadata(kind_md_pair.first, kind_md_pair.second);
2649 }
2650 }
2651 }
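// In C terms the two paths above are roughly (sketch only):
//   if (element_count == 1)
//     *(T*)target = *(T*)source;                        // aligned load + store
//   else
//     memcpy(target, source, element_count * sizeof(T));
// where T stands for the C type of `primitive_type`, plus the alias metadata
// annotations that C cannot express.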
2652
2653 Status IrEmitter::HandleConcatenate(HloInstruction* concatenate) {
2654 absl::Span<HloInstruction* const> operands(concatenate->operands());
2655 string failure_reason;
2656 TF_ASSIGN_OR_RETURN(
2657 bool successful,
2658 EmitFastConcatenate(concatenate, operands, &failure_reason));
2659 if (successful) {
2660 VLOG(1) << "Emitted fast concatenate for " << concatenate->ToString();
2661 return Status::OK();
2662 }
2663
2664 VLOG(1) << "Could not emit fast concatenate for " << concatenate->ToString()
2665 << ": " << failure_reason;
2666
2667 return DefaultAction(concatenate);
2668 }
2669
2670 Status IrEmitter::HandleConditional(HloInstruction* conditional) {
2671 auto branch_index = conditional->operand(0);
2672 int num_branches = conditional->branch_count();
2673 TF_RET_CHECK(ShapeUtil::IsScalar(branch_index->shape()) &&
2674 (branch_index->shape().element_type() == PRED ||
2675 branch_index->shape().element_type() == S32))
2676 << "Branch index on a conditional must be scalar bool or int32; got: "
2677 << ShapeUtil::HumanString(branch_index->shape());
2678
2679 for (int b = 0; b < num_branches; ++b) {
2680 HloComputation* br_computation = conditional->branch_computation(b);
2681 TF_RET_CHECK(ShapeUtil::Equal(conditional->shape(),
2682 br_computation->root_instruction()->shape()))
2683 << "Shape of conditional should be same as the shape of the " << b
2684 << "th branch computation; got: "
2685 << ShapeUtil::HumanString(conditional->shape()) << " and "
2686 << ShapeUtil::HumanString(br_computation->root_instruction()->shape());
2687 }
2688
2689 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(conditional));
2690
2691 if (branch_index->shape().element_type() == PRED) {
2692 // Emit an if-else to LLVM:
2693 // if (pred)
2694 // cond_result = true_computation(true_operand)
2695 // else
2696 // cond_result = false_computation(false_operand)
2697 llvm::LoadInst* pred_value = Load(
2698 GetIrArrayFor(branch_index).GetBasePointer(), "load_predicate_value");
2699 llvm::Value* pred_cond =
2700 ICmpNE(pred_value,
2701 llvm::ConstantInt::get(
2702 llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0),
2703 "boolean_predicate");
2704 llvm_ir::LlvmIfData if_data =
2705 llvm_ir::EmitIfThenElse(pred_cond, "conditional", &b_);
2706
2707 SetToFirstInsertPoint(if_data.true_block, &b_);
2708 EmitGlobalCall(*conditional->branch_computation(0),
2709 IrName(conditional, "_true"));
2710
2711 SetToFirstInsertPoint(if_data.false_block, &b_);
2712 EmitGlobalCall(*conditional->branch_computation(1),
2713 IrName(conditional, "_false"));
2714
2715 SetToFirstInsertPoint(if_data.after_block, &b_);
2716 return Status::OK();
2717 }
2718 // We emit a switch statement to LLVM:
2719 // switch (branch_index) {
2720 // default:
2721 // result = branch_computations[num_branches-1](operands[num_branches-1]);
2722 // break;
2723 // case 0:
2724 // result = branch_computations[0](operands[0]); break;
2725 // case 1:
2726 // result = branch_computations[1](operands[1]); break;
2727 // ...
2728 // case [[num_branches-2]]:
2729 // result = branch_computations[num_branches-2](operands[num_branches-2]);
2730 // break;
2731 // }
2732 llvm::LoadInst* branch_index_value = Load(
2733 GetIrArrayFor(branch_index).GetBasePointer(), "load_branch_index_value");
2734
2735 auto case_block = b_.GetInsertBlock();
2736 llvm::BasicBlock* after_block;
2737 // Add a terminator to the case block, if necessary.
2738 if (case_block->getTerminator() == nullptr) {
2739 after_block = llvm_ir::CreateBasicBlock(nullptr, "case-after", &b_);
2740 b_.SetInsertPoint(case_block);
2741 b_.CreateBr(after_block);
2742 } else {
2743 after_block =
2744 case_block->splitBasicBlock(b_.GetInsertPoint(), "case-after");
2745 }
2746 // Our basic block should now end with an unconditional branch. Remove it;
2747 // we're going to replace it with a switch-based branch.
2748 case_block->getTerminator()->eraseFromParent();
2749
2750 // Lower the default branch computation.
2751 auto default_block = llvm_ir::CreateBasicBlock(nullptr, "case-default", &b_);
2752 b_.SetInsertPoint(default_block);
2753 EmitGlobalCall(*conditional->branch_computation(num_branches - 1),
2754 IrName(conditional, "_default"));
2755 b_.CreateBr(after_block);
2756
2757 // Prepare the switch (branch_index) { ... } instruction.
2758 b_.SetInsertPoint(case_block);
2759 llvm::SwitchInst* case_inst =
2760 b_.CreateSwitch(branch_index_value, default_block, num_branches - 1);
2761 // Lower each branch's computation.
2762 for (int b = 0; b < num_branches - 1; ++b) { // last branch is default
2763 // Lower the case b: { ... ; break; } computation.
2764 auto branch_block =
2765 llvm_ir::CreateBasicBlock(nullptr, absl::StrCat("case-branch", b), &b_);
2766 b_.SetInsertPoint(branch_block);
2767 EmitGlobalCall(*conditional->branch_computation(b),
2768 IrName(conditional, absl::StrCat("_branch", b)));
2769 b_.CreateBr(after_block);
2770 case_inst->addCase(b_.getInt32(b), branch_block);
2771 }
2772
2773 SetToFirstInsertPoint(after_block, &b_);
2774 return Status::OK();
2775 }
2776
2777 Status IrEmitter::HandleAfterAll(HloInstruction* after_all) {
2778 TF_RET_CHECK(ByteSizeOf(after_all->shape()) == 0);
2779 // No code to generate, but we need to emit an address for book-keeping.
2780 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(after_all));
2781 return Status::OK();
2782 }
2783
2784 Status IrEmitter::HandleAddDependency(HloInstruction* add_dependency) {
2785 // AddDependency just forwards its zero-th operand.
2786 emitted_value_[add_dependency] =
2787 GetEmittedValueFor(add_dependency->operand(0));
2788 return Status::OK();
2789 }
2790
2791 Status IrEmitter::HandleRng(HloInstruction* rng) {
2792 return Unimplemented("Rng should be expanded for CPU.");
2793 }
2794
2795 Status IrEmitter::HandleRngGetAndUpdateState(HloInstruction* rng_state) {
2796 VLOG(2) << "RngGetAndUpdateState: " << rng_state->ToString();
2797 llvm::Value* old_state = llvm_ir::RngGetAndUpdateState(
2798 Cast<HloRngGetAndUpdateStateInstruction>(rng_state)->delta(), module_,
2799 &b_);
2800
2801 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(rng_state));
2802 llvm::Value* address = GetEmittedValueFor(rng_state);
2803
2804 // The buffer has an array type while the value is an i128. Cast the buffer
2805 // pointer to an i128 pointer so the value can be stored.
2806 address = BitCast(address, llvm::PointerType::get(
2807 old_state->getType()->getScalarType(),
2808 address->getType()->getPointerAddressSpace()));
2809 llvm::StoreInst* store = Store(old_state, address);
2810 store->setAlignment(llvm::Align(IrEmitter::MinimumAlignmentForPrimitiveType(
2811 rng_state->shape().element_type())));
2812
2813 return Status::OK();
2814 }
2815
2816 Status IrEmitter::FinishVisit(HloInstruction* root) {
2817 // When this method is called, we should have already emitted an IR value for
2818 // the root (return) op. The IR value holds the address of the buffer holding
2819 // the value. If the root is a constant or parameter, we perform a memcpy from
2820 // this buffer to the retval buffer of the computation. Otherwise, there's
2821 // nothing to do since the result was already written directly into the output
2822 // buffer.
2823 VLOG(2) << "FinishVisit root: " << root->ToString();
2824 if (root->opcode() == HloOpcode::kOutfeed) {
2825 VLOG(2) << " outfeed with value: "
2826 << llvm_ir::DumpToString(*GetEmittedValueFor(root->operand(0)));
2827 } else {
2828 VLOG(2) << " value: " << llvm_ir::DumpToString(*GetEmittedValueFor(root));
2829 }
2830
2831 auto record_complete_computation = [&](llvm::Value* prof_counter) {
2832 if (prof_counter) {
2833 profiling_state_.RecordCompleteComputation(&b_, prof_counter);
2834 }
2835 };
2836
2837 // For the entry computation this increment is cumulative over embedded
2838 // computations, since it includes cycles spent in computations invoked by
2839 // While, Call, etc.
2840 record_complete_computation(GetProfileCounterFor(*root->parent()));
2841 return Status::OK();
2842 }
2843
2844 template <typename T>
2845 llvm::Value* IrEmitter::GetProfileCounterCommon(
2846 const T& hlo,
2847 const std::unordered_map<const T*, int64>& profile_index_map) {
2848 auto it = profile_index_map.find(&hlo);
2849 if (it == profile_index_map.end()) {
2850 return nullptr;
2851 }
2852
2853 int64_t prof_counter_idx = it->second;
2854 string counter_name = IrName("prof_counter", hlo.name());
2855 return GEP(GetProfileCountersArgument(), b_.getInt64(prof_counter_idx),
2856 counter_name);
2857 }
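// The value returned above is the address of the counter slot, i.e. roughly
//   &profile_counters[prof_counter_idx]
// in C terms; ProfilingState later loads from and accumulates into that slot.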
2858
2859 llvm::Value* IrEmitter::GetProfileCounterFor(
2860 const HloInstruction& instruction) {
2861 return GetProfileCounterCommon<HloInstruction>(instruction,
2862 instruction_to_profile_idx_);
2863 }
2864
2865 llvm::Value* IrEmitter::GetProfileCounterFor(
2866 const HloComputation& computation) {
2867 return GetProfileCounterCommon<HloComputation>(computation,
2868 computation_to_profile_idx_);
2869 }
2870
2871 void IrEmitter::ProfilingState::UpdateProfileCounter(llvm::IRBuilder<>* b,
2872 llvm::Value* prof_counter,
2873 llvm::Value* cycle_end,
2874 llvm::Value* cycle_start) {
2875 auto* cycle_diff = b->CreateSub(cycle_end, cycle_start);
2876 llvm::LoadInst* old_cycle_count =
2877 b->CreateLoad(prof_counter, "old_cycle_count");
2878 auto* new_cycle_count =
2879 b->CreateAdd(cycle_diff, old_cycle_count, "new_cycle_count");
2880 b->CreateStore(new_cycle_count, prof_counter);
2881 }
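// The load/add/store sequence emitted above is equivalent to
//   *prof_counter += cycle_end - cycle_start;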
2882
2883 llvm::Value* IrEmitter::ProfilingState::ReadCycleCounter(llvm::IRBuilder<>* b) {
2884 llvm::Module* module = b->GetInsertBlock()->getModule();
2885 if (!use_rdtscp_) {
2886 llvm::Function* func_llvm_readcyclecounter =
2887 llvm::Intrinsic::getDeclaration(module,
2888 llvm::Intrinsic::readcyclecounter);
2889 return b->CreateCall(func_llvm_readcyclecounter);
2890 }
2891 llvm::Function* func_llvm_x86_rdtscp =
2892 llvm::Intrinsic::getDeclaration(module, llvm::Intrinsic::x86_rdtscp);
2893 llvm::Value* rdtscp_call = b->CreateCall(func_llvm_x86_rdtscp);
2894 return b->CreateExtractValue(rdtscp_call, {0});
2895 }
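// Sketch of the IR this produces: either
//   %cycles = call i64 @llvm.readcyclecounter()
// or, when use_rdtscp_ is set, a call to @llvm.x86.rdtscp followed by an
// extractvalue of field 0 (the timestamp counter).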
2896
2897 void IrEmitter::ProfilingState::RecordCycleStart(llvm::IRBuilder<>* b,
2898 HloInstruction* hlo) {
2899 auto* cycle_start = ReadCycleCounter(b);
2900 cycle_start->setName(IrName(hlo, "cycle_start"));
2901 cycle_starts_[hlo] = cycle_start;
2902 if (first_read_cycle_start_ == nullptr) {
2903 first_read_cycle_start_ = cycle_start;
2904 }
2905 }
2906
2907 void IrEmitter::ProfilingState::RecordCycleDelta(llvm::IRBuilder<>* b,
2908 HloInstruction* hlo,
2909 llvm::Value* prof_counter) {
2910 auto* cycle_end = ReadCycleCounter(b);
2911 cycle_end->setName(IrName(hlo, "cycle_end"));
2912 auto* cycle_start = cycle_starts_[hlo];
2913 UpdateProfileCounter(b, prof_counter, cycle_end, cycle_start);
2914 last_read_cycle_end_ = cycle_end;
2915 }
2916
2917 void IrEmitter::ProfilingState::RecordCompleteComputation(
2918 llvm::IRBuilder<>* b, llvm::Value* prof_counter) {
2919 if (last_read_cycle_end_ && first_read_cycle_start_) {
2920 UpdateProfileCounter(b, prof_counter, last_read_cycle_end_,
2921 first_read_cycle_start_);
2922 }
2923 }
2924
2925 void IrEmitter::TracingState::EmitTracingStart(llvm::IRBuilder<>* b,
2926 HloInstruction* hlo,
2927 llvm::Value* run_options) {
2928 if (!enabled_) {
2929 return;
2930 }
2931
2932 llvm::Type* int8_ptr_type = b->getInt8Ty()->getPointerTo();
2933 llvm::Type* void_ptr_type =
2934 int8_ptr_type; // LLVM does not have a void*; we use an int8* instead.
2935 llvm::FunctionType* fn_type =
2936 llvm::FunctionType::get(b->getInt64Ty(), {void_ptr_type, int8_ptr_type},
2937 /*isVarArg=*/false);
2938
2939 llvm::Function* function = b->GetInsertBlock()->getParent();
2940 llvm::Module* module = function->getParent();
2941 const char* fn_name = runtime::kTracingStartSymbolName;
2942 llvm::FunctionCallee trace_func =
2943 module->getOrInsertFunction(fn_name, fn_type);
2944 if (auto* fn = llvm::dyn_cast<llvm::Function>(trace_func.getCallee())) {
2945 fn->setCallingConv(llvm::CallingConv::C);
2946 fn->setDoesNotThrow();
2947 fn->setOnlyAccessesArgMemory();
2948 }
2949 auto* hlo_name = b->CreateGlobalStringPtr(hlo->name());
2950 auto* activity_id =
2951 b->CreateCall(trace_func, {b->CreateBitCast(run_options, void_ptr_type),
2952 b->CreateBitCast(hlo_name, int8_ptr_type)});
2953 activity_id->setName(IrName(hlo, "activity_id"));
2954 activity_ids_[hlo] = activity_id;
2955 }
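// Judging from fn_type above, the runtime symbol named by
// runtime::kTracingStartSymbolName is expected to behave like
//   int64_t TracingStart(const void* run_options, const char* hlo_name);
// (a sketch of the C signature inferred from fn_type, not a declaration from
// this file); the returned activity id is cached so EmitTracingEnd can close
// the same activity.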
2956
2957 void IrEmitter::TracingState::EmitTracingEnd(llvm::IRBuilder<>* b,
2958 HloInstruction* hlo,
2959 llvm::Value* run_options) {
2960 if (!enabled_) {
2961 return;
2962 }
2963
2964 llvm::Type* void_ptr_type =
2965 b->getInt8Ty()->getPointerTo(); // LLVM does not have a void*; we use an
2966 // int8* instead.
2967 llvm::FunctionType* fn_type =
2968 llvm::FunctionType::get(b->getVoidTy(), {void_ptr_type, b->getInt64Ty()},
2969 /*isVarArg=*/false);
2970
2971 llvm::Function* function = b->GetInsertBlock()->getParent();
2972 llvm::Module* module = function->getParent();
2973 const char* fn_name = runtime::kTracingEndSymbolName;
2974 llvm::FunctionCallee trace_func =
2975 module->getOrInsertFunction(fn_name, fn_type);
2976 if (auto* fn = llvm::dyn_cast<llvm::Function>(trace_func.getCallee())) {
2977 fn->setCallingConv(llvm::CallingConv::C);
2978 fn->setDoesNotThrow();
2979 fn->setOnlyAccessesArgMemory();
2980 }
2981 auto* activity_id = activity_ids_.at(hlo);
2982 b->CreateCall(trace_func,
2983 {b->CreateBitCast(run_options, void_ptr_type), activity_id});
2984 }
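// Mirror image of EmitTracingStart: the emitted call takes
// (run_options, activity_id) and returns void, ending the activity recorded
// for `hlo` in activity_ids_.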
2985
2986 namespace {
2987 bool IsHloVeryCheap(const HloInstruction* hlo) {
2988 return hlo->opcode() == HloOpcode::kBitcast ||
2989 hlo->opcode() == HloOpcode::kTuple ||
2990 hlo->opcode() == HloOpcode::kGetTupleElement ||
2991 hlo->opcode() == HloOpcode::kParameter ||
2992 hlo->opcode() == HloOpcode::kConstant;
2993 }
2994 } // namespace
2995
2996 Status IrEmitter::Preprocess(HloInstruction* hlo) {
2997 VLOG(3) << "Visiting: " << hlo->ToString();
2998 // When profiling is enabled, trace the same HLOs that the profiler does.
2999 if (instruction_to_profile_idx_.count(hlo) ||
3000 (hlo_module_config_.cpu_traceme_enabled() && !IsHloVeryCheap(hlo))) {
3001 tracing_state_.EmitTracingStart(&b_, hlo,
3002 GetExecutableRunOptionsArgument());
3003 profiling_state_.RecordCycleStart(&b_, hlo);
3004 }
3005 return Status::OK();
3006 }
3007
3008 Status IrEmitter::Postprocess(HloInstruction* hlo) {
3009 if (auto* prof_counter = GetProfileCounterFor(*hlo)) {
3010 profiling_state_.RecordCycleDelta(&b_, hlo, prof_counter);
3011 }
3012 // When profiling is enabled, trace the same HLOs that the profiler does.
3013 if (instruction_to_profile_idx_.count(hlo) ||
3014 (hlo_module_config_.cpu_traceme_enabled() && !IsHloVeryCheap(hlo))) {
3015 tracing_state_.EmitTracingEnd(&b_, hlo, GetExecutableRunOptionsArgument());
3016 }
3017 return Status::OK();
3018 }
3019
3020 llvm_ir::IrArray IrEmitter::GetIrArrayFor(const HloInstruction* hlo) {
3021 llvm::Value* value_for_op = GetEmittedValueFor(hlo);
3022
3023 llvm_ir::IrArray array(value_for_op, hlo->shape());
3024 AddAliasingInformationToIrArray(*hlo, &array);
3025 return array;
3026 }
3027
3028 std::vector<llvm_ir::IrArray> IrEmitter::GetIrArraysForOperandsOf(
3029 const HloInstruction* hlo) {
3030 std::vector<llvm_ir::IrArray> arrays;
3031 std::transform(
3032 hlo->operands().begin(), hlo->operands().end(),
3033 std::back_inserter(arrays),
3034 [&](const HloInstruction* operand) { return GetIrArrayFor(operand); });
3035 return arrays;
3036 }
3037
3038 llvm::Value* IrEmitter::GetEmittedValueFor(const HloInstruction* hlo) {
3039 auto it = emitted_value_.find(hlo);
3040 if (it == emitted_value_.end()) {
3041 LOG(FATAL) << "could not find emitted value for: " << hlo->ToString();
3042 }
3043 return it->second;
3044 }
3045
3046 llvm::Type* IrEmitter::IrShapeType(const Shape& shape) {
3047 return llvm_ir::ShapeToIrType(shape, module_);
3048 }
3049
3050 llvm::Value* IrEmitter::GetProfileCountersArgument() {
3051 return compute_function_->profile_counters_arg();
3052 }
3053
3054 llvm::Value* IrEmitter::GetBufferTableArgument() {
3055 return compute_function_->buffer_table_arg();
3056 }
3057
3058 llvm::Value* IrEmitter::GetExecutableRunOptionsArgument() {
3059 return compute_function_->exec_run_options_arg();
3060 }
3061
3062 llvm::Value* IrEmitter::EmitThreadLocalBufferPointer(
3063 const BufferAllocation::Slice& slice, const Shape& target_shape) {
3064 const BufferAllocation& allocation = *slice.allocation();
3065 llvm::Value* tempbuf_address = [&]() -> llvm::Value* {
3066 auto param_it =
3067 computation_parameter_allocations_.find(slice.allocation()->index());
3068 if (param_it != computation_parameter_allocations_.end()) {
3069 int64_t param_number = param_it->second;
3070 // We have to access the parameter at offset param_number in the params
3071 // array. The code generated here is equivalent to this C code:
3072 //
3073 // i8* param_address_untyped = params[param_number];
3074 // Param* param_address_typed = (Param*)param_address_untyped;
3075 //
3076 // Where Param is the actual element type of the underlying buffer (for
3077 // example, float for an XLA F32 element type).
3078 llvm::Value* params = compute_function_->parameters_arg();
3079 llvm::Value* param_address_offset =
3080 llvm_ir::EmitBufferIndexingGEP(params, param_number, &b_);
3081 llvm::LoadInst* param_address_untyped = Load(param_address_offset);
3082
3083 if (!target_shape.IsOpaque()) {
3084 AttachAlignmentMetadataForLoad(param_address_untyped, target_shape);
3085 AttachDereferenceableMetadataForLoad(param_address_untyped,
3086 target_shape);
3087 }
3088 return param_address_untyped;
3089 }
3090
3091 // Thread-local allocations should only be assigned a single buffer.
3092 const auto& assigned_buffers = allocation.assigned_buffers();
3093 CHECK_EQ(1, assigned_buffers.size());
3094 const Shape& shape = assigned_buffers.begin()->first->shape();
3095
3096 std::pair<llvm::Function*, BufferAllocation::Slice> key = {
3097 compute_function_->function(), slice};
3098 auto buf_it = thread_local_buffers_.find(key);
3099 if (buf_it == thread_local_buffers_.end()) {
3100 llvm::Value* buffer = llvm_ir::EmitAllocaAtFunctionEntry(
3101 IrShapeType(shape), absl::StrCat("thread_local", slice.ToString()),
3102 &b_, MinimumAlignmentForShape(target_shape));
3103 auto it_inserted_pair = thread_local_buffers_.insert({key, buffer});
3104 CHECK(it_inserted_pair.second);
3105 buf_it = it_inserted_pair.first;
3106 }
3107 return buf_it->second;
3108 }();
3109 return BitCast(tempbuf_address, IrShapeType(target_shape)->getPointerTo());
3110 }
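// For the non-parameter path above, the buffer is a function-entry alloca
// created once per (generated function, slice) pair and cached in
// thread_local_buffers_, so later requests for the same slice in the same
// function reuse the same stack slot.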
3111
3112 llvm::Value* IrEmitter::EmitGlobalBufferPointer(
3113 const BufferAllocation::Slice& slice, const Shape& target_shape) {
3114 const BufferAllocation& allocation = *slice.allocation();
3115 llvm::Value* tempbuf_address_ptr = llvm_ir::EmitBufferIndexingGEP(
3116 GetBufferTableArgument(), slice.index(), &b_);
3117 llvm::LoadInst* tempbuf_address_base = Load(tempbuf_address_ptr);
3118 if (hlo_module_config_.debug_options()
3119 .xla_llvm_enable_invariant_load_metadata()) {
3120 tempbuf_address_base->setMetadata(
3121 llvm::LLVMContext::MD_invariant_load,
3122 llvm::MDNode::get(tempbuf_address_base->getContext(), /*MDs=*/{}));
3123 }
3124 AttachAlignmentMetadataForLoad(tempbuf_address_base, allocation.size());
3125 AttachDereferenceableMetadataForLoad(tempbuf_address_base, allocation.size());
3126
3127 llvm::Value* tempbuf_address_untyped = tempbuf_address_base;
3128 if (slice.offset() > 0) {
3129 // Adjust the address to account for the slice offset.
3130 tempbuf_address_untyped =
3131 InBoundsGEP(tempbuf_address_base, b_.getInt64(slice.offset()));
3132 }
3133 return BitCast(tempbuf_address_untyped,
3134 IrShapeType(target_shape)->getPointerTo());
3135 }
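// In C terms the address computed above is roughly
//   (char*)buffer_table[slice.index()] + slice.offset()
// reinterpreted as a pointer to the IR type of `target_shape`.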
3136
3137 llvm::Value* IrEmitter::EmitBufferPointer(const BufferAllocation::Slice& slice,
3138 const Shape& target_shape) {
3139 if (slice.allocation()->is_thread_local()) {
3140 return EmitThreadLocalBufferPointer(slice, target_shape);
3141 } else if (slice.allocation()->is_constant()) {
3142 return BitCast(
3143 FindOrDie(constant_buffer_to_global_, slice.allocation()->index()),
3144 IrShapeType(target_shape)->getPointerTo());
3145 } else {
3146 return EmitGlobalBufferPointer(slice, target_shape);
3147 }
3148 }
3149
3150 Status IrEmitter::EmitTargetAddressForOp(const HloInstruction* op) {
3151 const Shape& target_shape = op->shape();
3152 TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice,
3153 assignment_.GetUniqueTopLevelSlice(op));
3154 llvm::Value* addr = EmitBufferPointer(slice, target_shape);
3155 addr->setName(IrName(op));
3156 emitted_value_[op] = addr;
3157 return Status::OK();
3158 }
3159
3160 Status IrEmitter::EmitTargetElementLoop(
3161 HloInstruction* target_op,
3162 const llvm_ir::ElementGenerator& element_generator) {
3163 return EmitTargetElementLoop(target_op, /*desc=*/"", element_generator);
3164 }
3165
3166 Status IrEmitter::EmitTargetElementLoop(
3167 HloInstruction* target_op, absl::string_view desc,
3168 const llvm_ir::ElementGenerator& element_generator) {
3169 VLOG(2) << "EmitTargetElementLoop: " << target_op->ToString();
3170
3171 const Shape& target_shape = target_op->shape();
3172 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(target_op));
3173 llvm_ir::IrArray target_array = GetIrArrayFor(target_op);
3174
3175 if (target_shape.IsTuple() &&
3176 (target_op->opcode() == HloOpcode::kFusion ||
3177 target_op->opcode() == HloOpcode::kReduce ||
3178 target_op->opcode() == HloOpcode::kReduceWindow)) {
3179 // For multi-output fusion, we need to emit each operand and the root.
3180 TF_RET_CHECK(num_dynamic_loop_bounds_ == 0);
3181 std::vector<llvm_ir::IrArray> output_arrays;
3182 for (int64_t i = 0; i < ShapeUtil::TupleElementCount(target_shape); ++i) {
3183 TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice,
3184 assignment_.GetUniqueSlice(target_op, {i}));
3185 const Shape& element_shape = ShapeUtil::GetSubshape(target_shape, {i});
3186 llvm::Value* op_target_address = EmitBufferPointer(slice, element_shape);
3187 output_arrays.push_back(
3188 llvm_ir::IrArray(op_target_address, element_shape));
3189 }
3190 TF_RETURN_IF_ERROR(
3191 llvm_ir::LoopEmitter(element_generator, output_arrays, &b_)
3192 .EmitLoop(IrName(target_op)));
3193
3194 std::vector<llvm::Value*> tuple_operand_ptrs;
3195 for (int64_t i = 0; i < output_arrays.size(); ++i) {
3196 tuple_operand_ptrs.push_back(output_arrays[i].GetBasePointer());
3197 }
3198 llvm_ir::EmitTuple(target_array, tuple_operand_ptrs, &b_);
3199
3200 } else {
3201 if (ShouldEmitParallelLoopFor(*target_op)) {
3202 // Emit code to read dynamic loop bounds from compute function argument.
3203 std::vector<std::pair<llvm::Value*, llvm::Value*>> dynamic_loop_bounds =
3204 compute_function_->GetDynamicLoopBounds();
3205 // Emit parallel loop with dynamic loop bounds for most-major dimensions.
3206 TF_RETURN_IF_ERROR(ParallelLoopEmitter(element_generator, target_array,
3207 &dynamic_loop_bounds, &b_)
3208 .EmitLoop(IrName(target_op)));
3209 } else {
3210 TF_RETURN_IF_ERROR(
3211 llvm_ir::LoopEmitter(element_generator, target_array, &b_)
3212 .EmitLoop(IrName(target_op)));
3213 }
3214 }
3215 return Status::OK();
3216 }
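// In the tuple-shaped case above, the element generator yields one value per
// tuple element at each index; the loop writes each value into its own output
// buffer, and the trailing EmitTuple stores the per-element base pointers into
// the target tuple buffer.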
3217
3218 Status IrEmitter::EmitMemcpy(const HloInstruction& source,
3219 const HloInstruction& destination) {
3220 llvm::Value* source_value = GetEmittedValueFor(&source);
3221 llvm::Value* destination_value = GetEmittedValueFor(&destination);
3222 int64_t source_size = ByteSizeOf(source.shape());
3223 // TODO(b/63762267): Be more aggressive about specifying alignment.
3224 MemCpy(destination_value, /*DstAlign=*/llvm::Align(1), source_value,
3225 /*SrcAlign=*/llvm::Align(1), source_size);
3226 return Status::OK();
3227 }
3228
3229 Status IrEmitter::ElementTypesSameAndSupported(
3230 const HloInstruction& instruction,
3231 absl::Span<const HloInstruction* const> operands,
3232 absl::Span<const PrimitiveType> supported_types) {
3233 for (auto operand : operands) {
3234 TF_RET_CHECK(
3235 ShapeUtil::SameElementType(operands[0]->shape(), operand->shape()));
3236 }
3237
3238 TF_RET_CHECK(!operands.empty());
3239 PrimitiveType primitive_type = operands[0]->shape().element_type();
3240 if (!absl::c_linear_search(supported_types, primitive_type)) {
3241 return Unimplemented("unsupported operand type %s in op %s",
3242 PrimitiveType_Name(primitive_type),
3243 HloOpcodeString(instruction.opcode()));
3244 }
3245 return Status::OK();
3246 }
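// Illustrative use only (not an actual call site in this file): a handler that
// only supports F32 and F64 operands might guard itself with
//   TF_RETURN_IF_ERROR(ElementTypesSameAndSupported(
//       *hlo, {hlo->operand(0), hlo->operand(1)}, {F32, F64}));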
3247
3248 Status IrEmitter::DefaultAction(HloInstruction* hlo) {
3249 ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator;
3250 for (const HloInstruction* operand : hlo->operands()) {
3251 operand_to_generator[operand] = [=](const llvm_ir::IrArray::Index& index) {
3252 return GetIrArrayFor(operand).EmitReadArrayElement(index, &b_);
3253 };
3254 }
3255 CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_);
3256 return EmitTargetElementLoop(
3257 hlo, elemental_emitter.MakeElementGenerator(hlo, operand_to_generator));
3258 }
3259
3260 llvm::Value* IrEmitter::EmitScalarReturningThreadLocalCall(
3261 const HloComputation& callee, absl::Span<llvm::Value* const> parameters,
3262 absl::string_view name) {
3263 std::vector<llvm::Value*> return_value =
3264 EmitThreadLocalCall(callee, parameters, name);
3265 CHECK_EQ(return_value.size(), 1);
3266 return return_value[0];
3267 }
3268
3269 std::vector<llvm::Value*> IrEmitter::EmitThreadLocalCall(
3270 const HloComputation& callee, absl::Span<llvm::Value* const> parameters,
3271 absl::string_view name) {
3272 CHECK(absl::c_binary_search(thread_local_computations_, &callee));
3273 const Shape& return_shape = callee.root_instruction()->shape();
3274 bool is_scalar_return = ShapeUtil::IsScalar(return_shape);
3275 bool is_tuple_of_scalars_return =
3276 return_shape.IsTuple() &&
3277 absl::c_all_of(return_shape.tuple_shapes(), [&](const Shape& shape) {
3278 return ShapeUtil::IsScalar(shape);
3279 });
3280 CHECK(is_scalar_return || is_tuple_of_scalars_return);
3281
3282 std::vector<llvm::Value*> parameter_addrs;
3283 for (llvm::Value* parameter : parameters) {
3284 CHECK(!parameter->getType()->isPointerTy());
3285 llvm::Value* parameter_addr = llvm_ir::EmitAllocaAtFunctionEntry(
3286 parameter->getType(), "arg_addr", &b_);
3287 Store(parameter, parameter_addr);
3288 parameter_addrs.push_back(parameter_addr);
3289 }
3290
3291 llvm::Type* return_value_buffer_type =
3292 llvm_ir::ShapeToIrType(return_shape, module_);
3293 std::string retval_alloca_name = absl::StrCat(name, "_return_value_addr");
3294 int retval_alignment =
3295 is_scalar_return
3296 ? MinimumAlignmentForPrimitiveType(return_shape.element_type())
3297 : 0;
3298 llvm::Value* return_value_buffer = llvm_ir::EmitAllocaAtFunctionEntry(
3299 return_value_buffer_type, retval_alloca_name, &b_, retval_alignment);
3300
3301 std::vector<llvm::Value*> allocas_for_returned_scalars;
3302 if (is_scalar_return) {
3303 allocas_for_returned_scalars.push_back(return_value_buffer);
3304 } else {
3305 constexpr int max_tuple_size = 1000;
3306 CHECK_LT(return_shape.tuple_shapes_size(), max_tuple_size)
3307 << "Multivalue function can not return more than 1000 elements to avoid"
3308 << " stack smashing";
3309 allocas_for_returned_scalars =
3310 llvm_ir::EmitTupleAllocasAtFunctionEntry(return_shape, &b_);
3311 llvm_ir::IrArray tuple_array(return_value_buffer, return_shape);
3312
3313 EmitTuple(tuple_array, allocas_for_returned_scalars, &b_);
3314 }
3315
3316 Call(FindOrDie(emitted_functions_, &callee),
3317 GetArrayFunctionCallArguments(
3318 parameter_addrs, &b_, name,
3319 /*return_value_buffer=*/return_value_buffer,
3320 /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(),
3321 /*buffer_table_arg=*/
3322 llvm::Constant::getNullValue(b_.getInt8PtrTy()->getPointerTo()),
3323 /*profile_counters_arg=*/GetProfileCountersArgument()));
3324
3325 std::vector<llvm::Value*> returned_scalars;
3326 returned_scalars.reserve(allocas_for_returned_scalars.size());
3327 for (llvm::Value* addr : allocas_for_returned_scalars) {
3328 returned_scalars.push_back(Load(addr));
3329 }
3330 return returned_scalars;
3331 }
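// Summary of the convention implemented above: each scalar argument is spilled
// to a function-entry alloca, the callee writes its scalar (or
// tuple-of-scalars) result into a stack-allocated return buffer, and the
// buffer table argument passed to the callee is a null pointer.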
3332
3333 void IrEmitter::EmitGlobalCall(const HloComputation& callee,
3334 absl::string_view name) {
3335 CHECK(absl::c_binary_search(global_computations_, &callee));
3336
3337 Call(FindOrDie(emitted_functions_, &callee),
3338 GetArrayFunctionCallArguments(
3339 /*parameter_addresses=*/{}, &b_, name,
3340 /*return_value_buffer=*/
3341 llvm::Constant::getNullValue(b_.getInt8PtrTy()),
3342 /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(),
3343 /*buffer_table_arg=*/GetBufferTableArgument(),
3344 /*profile_counters_arg=*/GetProfileCountersArgument()));
3345 }
3346
3347 llvm::Value* IrEmitter::GetBufferForGlobalCallReturnValue(
3348 const HloComputation& callee) {
3349 const HloInstruction* root_inst = callee.root_instruction();
3350 if (root_inst->opcode() == HloOpcode::kOutfeed) {
3351 return llvm::Constant::getNullValue(b_.getInt8PtrTy());
3352 }
3353
3354 const BufferAllocation::Slice root_buffer =
3355 assignment_.GetUniqueTopLevelSlice(root_inst).ValueOrDie();
3356 return EmitBufferPointer(root_buffer, root_inst->shape());
3357 }
3358
3359 void IrEmitter::BindFusionArguments(const HloInstruction* fusion,
3360 FusedIrEmitter* fused_emitter) {
3361 for (int i = 0; i < fusion->operand_count(); i++) {
3362 const HloInstruction* operand = fusion->operand(i);
3363 fused_emitter->BindGenerator(
3364 fusion->fused_parameter(i),
3365 [this, operand](llvm_ir::IrArray::Index index) {
3366 return GetIrArrayFor(operand).EmitReadArrayElement(index, &b_);
3367 });
3368 }
3369 }
3370
3371 } // namespace cpu
3372 } // namespace xla
3373