• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <memory>
17 #include <vector>
18 
19 #include "tensorflow/compiler/xla/service/gpu/ir_emitter_nested.h"
20 
21 #include "llvm/IR/BasicBlock.h"
22 #include "llvm/IR/Function.h"
23 #include "llvm/IR/IRBuilder.h"
24 #include "llvm/IR/Instructions.h"
25 #include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h"
26 #include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h"
27 #include "tensorflow/compiler/xla/service/hlo_computation.h"
28 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
29 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
30 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
31 #include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h"
32 #include "tensorflow/compiler/xla/service/name_uniquer.h"
33 #include "tensorflow/core/lib/core/status.h"
34 
35 namespace xla {
36 namespace gpu {
37 
IrEmitterNested(const HloModuleConfig & hlo_module_config,const HloComputation & nested_computation,IrEmitterContext * ir_emitter_context)38 IrEmitterNested::IrEmitterNested(const HloModuleConfig& hlo_module_config,
39                                  const HloComputation& nested_computation,
40                                  IrEmitterContext* ir_emitter_context)
41     : IrEmitter(hlo_module_config, ir_emitter_context, /*is_nested=*/true),
42       nested_computation_(nested_computation) {}
43 
Create(const HloModuleConfig & hlo_module_config,const HloComputation & nested_computation,IrEmitterContext * ir_emitter_context)44 StatusOr<std::unique_ptr<IrEmitterNested>> IrEmitterNested::Create(
45     const HloModuleConfig& hlo_module_config,
46     const HloComputation& nested_computation,
47     IrEmitterContext* ir_emitter_context) {
48   std::unique_ptr<IrEmitterNested> emitter(new IrEmitterNested(
49       hlo_module_config, nested_computation, ir_emitter_context));
50   TF_RETURN_IF_ERROR(emitter->EmitConstants(nested_computation, false));
51   return emitter;
52 }
53 
54 // Nested function serves the same purpose on GPU as a thread-local function on
55 // a CPU.
CodegenNestedComputation()56 Status IrEmitterNested::CodegenNestedComputation() {
57   std::vector<const HloInstruction*> io_hlos;
58   std::vector<llvm::Type*> argument_types;
59   std::vector<int64> argument_dereferenceable_bytes;
60   for (const HloInstruction* param :
61        nested_computation_.parameter_instructions()) {
62     io_hlos.push_back(param);
63     const Shape& param_shape = param->shape();
64     argument_types.push_back(
65         llvm_ir::ShapeToIrType(param_shape, module_)->getPointerTo());
66     int64 param_size =
67         llvm_ir::ByteSizeOf(param_shape, module_->getDataLayout());
68     argument_dereferenceable_bytes.push_back(param_size);
69   }
70 
71   const HloInstruction* root = nested_computation_.root_instruction();
72   {
73     const Shape& root_shape = root->shape();
74     argument_types.push_back(
75         llvm_ir::ShapeToIrType(root_shape, module_)->getPointerTo());
76     int64 root_size = llvm_ir::ByteSizeOf(
77         root_shape, ir_emitter_context_->llvm_module()->getDataLayout());
78     argument_dereferenceable_bytes.push_back(root_size);
79   }
80 
81   llvm::FunctionType* function_type =
82       llvm::FunctionType::get(b_.getVoidTy(), argument_types, false);
83   llvm::Function* function = llvm::Function::Create(
84       function_type,                       // The function type.
85       llvm::GlobalValue::InternalLinkage,  // The linkage type.
86       ir_emitter_context_->name_uniquer()->GetUniqueName(
87           llvm_ir::SanitizeFunctionName(
88               nested_computation_.name())),  // The name of the function.
89       ir_emitter_context_->llvm_module());   // The parent LLVM module.
90   for (size_t arg_no = 0; arg_no < argument_dereferenceable_bytes.size();
91        ++arg_no) {
92     int64 arg_size = argument_dereferenceable_bytes[arg_no];
93     if (arg_size > 0) {
94       function->addDereferenceableAttr(arg_no + 1, arg_size);
95     }
96   }
97 
98   // TODO(b/65380986): Investigate if adding fast math flags for generated
99   // kernels makes sense.
100 
101   llvm::BasicBlock* entry_bb =
102       llvm::BasicBlock::Create(function->getContext(), "entry", function);
103   // Emit a "return void" at entry_bb's end, and sets the insert point before
104   // that return instruction.
105   llvm::ReturnInst* ret_instr =
106       llvm::ReturnInst::Create(function->getContext(), entry_bb);
107   b_.SetInsertPoint(ret_instr);
108 
109   std::vector<const HloInstruction*> non_io_hlos;
110   non_io_hlos.push_back(root);
111   for (const auto* hlo : nested_computation_.instructions()) {
112     if (hlo->opcode() != HloOpcode::kParameter &&
113         hlo != nested_computation_.root_instruction()) {
114       non_io_hlos.push_back(hlo);
115     }
116   }
117   bindings_.EmitBasePointersForHlos(io_hlos, non_io_hlos);
118 
119   TF_RETURN_IF_ERROR(nested_computation_.root_instruction()->Accept(this));
120   b_.SetInsertPoint(ret_instr);
121 
122   // Function epilogue: copy the output value back.
123   {
124     // TODO(cheshire) Duplication vs. EmitThreadLocalFunctionEpilogue
125     const HloInstruction* root_instruction =
126         nested_computation_.root_instruction();
127     llvm::Value* root_value = bindings_.GetBasePointer(*root_instruction);
128     const Shape& return_shape = root_instruction->shape();
129 
130     // Last argument is the out parameter.
131     llvm::Argument* out_parameter = std::prev(function->arg_end(), 1);
132 
133     if (ShapeUtil::IsScalar(return_shape)) {
134       llvm::Value* ret_value = Load(root_value, "load_ret_value");
135       Store(ret_value,
136             BitCast(out_parameter, root_value->getType(), "bitcast_ret_value"));
137     } else {
138       CHECK(return_shape.IsTuple());
139       llvm::Type* tuple_type = llvm_ir::ShapeToIrType(return_shape, module_);
140       llvm::Type* tuple_type_ptr = tuple_type->getPointerTo();
141       llvm::Value* tuple_ptr = BitCast(out_parameter, tuple_type_ptr);
142 
143       for (int i = 0; i < return_shape.tuple_shapes_size(); i++) {
144         const Shape& element_shape = return_shape.tuple_shapes(i);
145         llvm::Value* destination =
146             llvm_ir::EmitGetTupleElement(element_shape,
147                                          /*index=*/i,
148                                          /*alignment=*/1, tuple_ptr, &b_);
149         llvm::Value* source =
150             llvm_ir::EmitGetTupleElement(element_shape,
151                                          /*index=*/i,
152                                          /*alignment=*/1, root_value, &b_);
153         Store(Load(source), destination);
154       }
155     }
156   }
157   b_.SetInsertPoint(ret_instr);
158   emitted_function_ = function;
159   return Status::OK();
160 }
161 
HandleParameter(HloInstruction * parameter)162 Status IrEmitterNested::HandleParameter(HloInstruction* parameter) {
163   return Status::OK();
164 }
165 
EmitTargetElementLoop(const HloInstruction & hlo,const llvm_ir::ElementGenerator & element_generator)166 Status IrEmitterNested::EmitTargetElementLoop(
167     const HloInstruction& hlo,
168     const llvm_ir::ElementGenerator& element_generator) {
169   // For MOF we give the loop emitter an array for every output it should
170   // generate.
171   if (hlo.shape().IsTuple()) {
172     std::vector<llvm_ir::IrArray> target_arrays =
173         ConstructIrArrayForOutputs(hlo);
174     TF_RETURN_IF_ERROR(
175         llvm_ir::LoopEmitter(element_generator, target_arrays, &b_).EmitLoop());
176     llvm_ir::EmitTuple(GetIrArray(hlo, hlo), target_arrays, &b_);
177     return Status::OK();
178   }
179   return llvm_ir::LoopEmitter(element_generator, GetIrArray(hlo, hlo), &b_)
180       .EmitLoop();
181 }
182 
183 }  // namespace gpu
184 }  // namespace xla
185