• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
17 
18 #include <algorithm>
19 #include <functional>
20 
21 #include "llvm/IR/BasicBlock.h"
22 #include "llvm/IR/Value.h"
23 #include "tensorflow/compiler/xla/map_util.h"
24 #include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"
25 #include "tensorflow/compiler/xla/service/fusion_node_indexing_evaluation.h"
26 #include "tensorflow/compiler/xla/service/hlo_computation.h"
27 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
28 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
29 #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
30 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
31 #include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h"
32 #include "tensorflow/compiler/xla/shape.h"
33 #include "tensorflow/compiler/xla/shape_util.h"
34 #include "tensorflow/compiler/xla/status_macros.h"
35 #include "tensorflow/compiler/xla/statusor.h"
36 #include "tensorflow/compiler/xla/util.h"
37 #include "tensorflow/core/platform/logging.h"
38 
39 namespace xla {
40 
41 using llvm_ir::IrArray;
42 
DefaultAction(const HloInstruction * hlo)43 Status FusedIrEmitter::DefaultAction(const HloInstruction* hlo) {
44   indexed_generators_[hlo] =
45       [=](const IrArray::Index& index) -> StatusOr<llvm::Value*> {
46     if (llvm::Value* generated_value = FindOrDefault(
47             generated_value_cache_[hlo], index.multidim(), nullptr)) {
48       llvm::BasicBlock* generated_value_bb = nullptr;
49       if (auto* generated_instruction =
50               llvm::dyn_cast<llvm::Instruction>(generated_value)) {
51         generated_value_bb = generated_instruction->getParent();
52       }
53       // Ideally, we should be able to reuse the cached generated value if it
54       // dominates the current insertion block. However, the check for dominance
55       // can be expensive and unreliable when the function is being constructed.
56       //
57       // It's also worth experimenting what if we don't do caching at all.
58       // LLVM's CSE or GVN should be able to easily merge common subexpressions
59       // that would be regenerated without caching. But this might increase the
60       // JIT compilation time.
61       if (generated_value_bb == nullptr ||
62           generated_value_bb == b_->GetInsertBlock()) {
63         VLOG(3) << "The cached generated value is reused.";
64         return generated_value;
65       }
66       VLOG(3) << "The cached generated value can't be reused, because it is in "
67                  "a different BB ("
68               << generated_value_bb->getName().str()
69               << ") from the current insertion block ("
70               << b_->GetInsertBlock()->getName().str() << ").";
71     }
72 
73     TF_ASSIGN_OR_RETURN(llvm::Value* const generated_value,
74                         elemental_emitter_->MakeElementGenerator(
75                             hlo, indexed_generators_)(index));
76     generated_value_cache_[hlo][index.multidim()] = generated_value;
77     return generated_value;
78   };
79   return Status::OK();
80 }
81 
HandleConstant(const HloInstruction * constant)82 Status FusedIrEmitter::HandleConstant(const HloInstruction* constant) {
83   indexed_generators_[constant] = [=](const IrArray::Index& index) {
84     const Literal& literal = constant->literal();
85     llvm::Constant* initializer =
86         llvm_ir::ConvertLiteralToIrConstant(literal, module_);
87     llvm::GlobalVariable* global = new llvm::GlobalVariable(
88         *b_->GetInsertBlock()->getModule(), initializer->getType(),
89         /*isConstant=*/true,
90         /*Linkage=*/llvm::GlobalValue::PrivateLinkage,
91         /*Initializer=*/initializer,
92         /*Name=*/"", /*InsertBefore=*/nullptr,
93         /*TLMode=*/llvm::GlobalValue::NotThreadLocal,
94         /*AddressSpace=*/0,
95         /*isExternallyInitialized=*/false);
96 
97     global->setUnnamedAddr(llvm::GlobalVariable::UnnamedAddr::Global);
98     llvm::Constant* shape_constant =
99         llvm::ConstantExpr::getPointerBitCastOrAddrSpaceCast(
100             global,
101             llvm_ir::ShapeToIrType(literal.shape(), module_)->getPointerTo());
102     return IrArray(shape_constant, constant->shape())
103         .EmitReadArrayElement(index, b_, constant->name());
104   };
105 
106   return Status::OK();
107 }
108 
HandleGetTupleElement(const HloInstruction * get_tuple_element)109 Status FusedIrEmitter::HandleGetTupleElement(
110     const HloInstruction* get_tuple_element) {
111   return InternalError("Tuple parameters are not supported for fusion");
112 }
113 
HandleParameter(const HloInstruction * parameter)114 Status FusedIrEmitter::HandleParameter(const HloInstruction* parameter) {
115   if (indexed_generators_.find(parameter) == indexed_generators_.end()) {
116     return InvalidArgument("Unbound parameter: %s", parameter->ToString());
117   }
118   return Status::OK();
119 }
120 
HandleTuple(const HloInstruction * tuple)121 Status FusedIrEmitter::HandleTuple(const HloInstruction* tuple) {
122   absl::Span<HloInstruction* const> operands(tuple->operands());
123   std::vector<llvm::Type*> operand_elemental_ir_types;
124   for (HloInstruction* operand : operands) {
125     operand_elemental_ir_types.push_back(llvm_ir::PrimitiveTypeToIrType(
126         operand->shape().element_type(), module_));
127   }
128   indexed_generators_[tuple] =
129       [=](const IrArray::Index& index) -> StatusOr<llvm::Value*> {
130     llvm::Value* ret = llvm::UndefValue::get(
131         llvm::StructType::get(b_->getContext(), operand_elemental_ir_types));
132     for (size_t i = 0; i < ShapeUtil::TupleElementCount(tuple->shape()); ++i) {
133       TF_ASSIGN_OR_RETURN(llvm::Value * val_i,
134                           indexed_generators_[operands[i]](index));
135       ret = b_->CreateInsertValue(ret, val_i, i);
136     }
137     return ret;
138   };
139   return Status::OK();
140 }
141 
IsFusedIrEmitterInefficient(const HloInstruction * consumer,const HloInstruction * producer)142 bool FusedIrEmitter::IsFusedIrEmitterInefficient(
143     const HloInstruction* consumer, const HloInstruction* producer) {
144   if (consumer->opcode() != HloOpcode::kFusion) {
145     return false;
146   }
147   FusionNodeIndexingEvaluation eval_consumer(consumer);
148   if (producer->opcode() != HloOpcode::kFusion) {
149     return eval_consumer.CodeDuplicationTooHigh(producer);
150   }
151   // If 'producer' is a fusion node as well, also evaluate it. Pass the
152   // evaluated duplication of the fusion node if it is merged into consumer.
153   FusionNodeIndexingEvaluation eval_producer(
154       producer, eval_consumer.EvaluateEmittedInstructions(producer));
155   return eval_producer.MaxCodeDuplicationTooHigh();
156 }
157 
GetGenerator(const HloInstruction * instruction)158 StatusOr<FusedIrEmitter::IndexedGenerator> FusedIrEmitter::GetGenerator(
159     const HloInstruction* instruction) {
160   std::vector<const HloInstruction*> stack;
161   stack.push_back(instruction);
162   while (!stack.empty()) {
163     const HloInstruction* instr = stack.back();
164     stack.pop_back();
165     if (indexed_generators_.count(instr)) {
166       continue;
167     }
168     for (const HloInstruction* operand : instr->operands()) {
169       stack.push_back(operand);
170     }
171     switch (instr->opcode()) {
172       case HloOpcode::kConstant:
173         TF_RETURN_IF_ERROR(HandleConstant(instr));
174         break;
175       case HloOpcode::kGetTupleElement:
176         TF_RETURN_IF_ERROR(HandleGetTupleElement(instr));
177         break;
178       case HloOpcode::kParameter:
179         TF_RETURN_IF_ERROR(HandleParameter(instr));
180         break;
181       case HloOpcode::kTuple:
182         TF_RETURN_IF_ERROR(HandleTuple(instr));
183         break;
184       default:
185         TF_RETURN_IF_ERROR(DefaultAction(instr));
186         break;
187     }
188     CHECK(indexed_generators_.count(instr));
189   }
190   return indexed_generators_.at(instruction);
191 }
192 
193 }  // namespace xla
194