1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h"
17
18 #include "absl/memory/memory.h"
19 #include "absl/strings/str_cat.h"
20 #include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h"
21 #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
22 #include "tensorflow/compiler/xla/service/cpu/shape_partition.h"
23 #include "tensorflow/compiler/xla/service/hlo_computation.h"
24 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
25 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
26 #include "tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h"
27
28 namespace xla {
29 namespace cpu {
30
31 class SimpleCostModel : public ParallelCostModel {
32 public:
SimpleCostModel(const int64 max_parallelism,const HloCostAnalysis::ShapeSizeFunction & shape_size)33 SimpleCostModel(const int64 max_parallelism,
34 const HloCostAnalysis::ShapeSizeFunction& shape_size)
35 : max_parallelism_(max_parallelism), shape_size_(shape_size) {}
~SimpleCostModel()36 ~SimpleCostModel() override {}
37
GetParallelTaskCount(HloInstruction * instruction)38 int64 GetParallelTaskCount(HloInstruction* instruction) override {
39 // Simple cost model based on hlo size and typical L2 cache size.
40 const int64 instruction_cost = shape_size_(instruction->shape());
41 const int64 min_cost_per_thread = 256LL << 10; // 256KB L2 Cache size.
42 // Return target parallel task count in [1, max_parallelism_].
43 return std::min(max_parallelism_,
44 std::max(int64{1}, instruction_cost / min_cost_per_thread));
45 }
46
47 private:
48 const int64 max_parallelism_;
49 const HloCostAnalysis::ShapeSizeFunction shape_size_;
50 };
51
52 class DefaultCostModel : public ParallelCostModel {
53 public:
DefaultCostModel(const int64 max_parallelism,const HloCostAnalysis::ShapeSizeFunction & shape_size,std::unique_ptr<HloCostAnalysis> cost_analysis)54 DefaultCostModel(const int64 max_parallelism,
55 const HloCostAnalysis::ShapeSizeFunction& shape_size,
56 std::unique_ptr<HloCostAnalysis> cost_analysis)
57 : max_parallelism_(max_parallelism),
58 shape_size_(shape_size),
59 cost_analysis_(std::move(cost_analysis)) {}
~DefaultCostModel()60 ~DefaultCostModel() override {}
61
GetParallelTaskCount(HloInstruction * instruction)62 int64 GetParallelTaskCount(HloInstruction* instruction) override {
63 // Parameters for parallel task count computation.
64 int64 instruction_cost;
65 int64 min_cost_per_thread;
66 int64 max_parallelism;
67 // Calculate flops-to-bytes-ratio for 'instruction'.
68 const int64 bytes_accessed =
69 std::max(int64{1}, cost_analysis_->bytes_accessed(*instruction));
70 const float flops_to_bytes_ratio =
71 cost_analysis_->flop_count(*instruction) /
72 static_cast<float>(bytes_accessed);
73 // Check for I/O bound instructions.
74 if (flops_to_bytes_ratio <= 1.0) {
75 // Limit max parallelism for I/O bound instructions by assuming a
76 // sub-linear scaling function (fit based on empirical benchmark results).
77 // TODO(b/29630486) Develop system bandwidth model.
78 max_parallelism = std::min<int64>(
79 max_parallelism_,
80 std::ceil(std::sqrt(tensorflow::port::MaxParallelism())));
81 // Use shape size instruction cost and L2 cache size min per-thread cost.
82 instruction_cost = shape_size_(instruction->shape());
83 min_cost_per_thread = 256LL << 10; // 256KB L2 Cache size.
84 } else {
85 // Use max parallelism for compute bound instructions.
86 max_parallelism = max_parallelism_;
87 // Calculate the instruction cost in cycles.
88 // TODO(b/29630486) Improve on this linear cost model.
89 // Consider making 'min_cost_per_thread' be a function of the target
90 // bandwidth limit for instructions with low arithmetic complexity.
91 instruction_cost =
92 1 * cost_analysis_->flop_count(*instruction) +
93 2 * cost_analysis_->transcendental_count(*instruction) +
94 10 * cost_analysis_->bytes_accessed(*instruction);
95 // Minimum per-thread cost is 100us of work on a 2GHz core.
96 min_cost_per_thread = 100000;
97 }
98 // Return target parallel task count in [1, max_parallelism_].
99 return std::min(max_parallelism,
100 std::max(int64{1}, instruction_cost / min_cost_per_thread));
101 }
102
103 private:
104 const int64 max_parallelism_;
105 const HloCostAnalysis::ShapeSizeFunction shape_size_;
106 const std::unique_ptr<HloCostAnalysis> cost_analysis_;
107 };
108
ParallelTaskAssignment(const int64 max_parallelism,const HloCostAnalysis::ShapeSizeFunction & shape_size,HloModule * module,const TargetMachineFeatures * target_machine_features)109 ParallelTaskAssignment::ParallelTaskAssignment(
110 const int64 max_parallelism,
111 const HloCostAnalysis::ShapeSizeFunction& shape_size, HloModule* module,
112 const TargetMachineFeatures* target_machine_features)
113 : target_machine_features_(*target_machine_features) {
114 VLOG(1) << "ParallelTaskAssignment max_parallelism: " << max_parallelism;
115 // Run cost analysis on 'module'.
116 auto cost_analysis = absl::make_unique<HloCostAnalysis>(shape_size);
117 HloComputation* computation = module->entry_computation();
118 Status status = computation->root_instruction()->Accept(cost_analysis.get());
119 if (status.ok()) {
120 // Set default cost model based on 'cost_analysis'.
121 cost_model_.reset(new DefaultCostModel(max_parallelism, shape_size,
122 std::move(cost_analysis)));
123 } else {
124 // Fall back to a simple cost model based on hlo size and L2 cache size.
125 // Note that HloCostAnalysis can returns an error status (likely because
126 // HLOs like CustomCall are not yet implemented in the HloCostAnalysis).
127 cost_model_.reset(new SimpleCostModel(max_parallelism, shape_size));
128 }
129 }
130
GetTargetParallelTaskCount(HloInstruction * instruction)131 int64 ParallelTaskAssignment::GetTargetParallelTaskCount(
132 HloInstruction* instruction) {
133 // Currently, we do not assign parallel tasks to instructions with at least
134 // one of the following properties:
135 // *) Internal threading (library calls to kConv, kDot, kFft, kCustomCall).
136 // *) Emit custom loops (kSelectAndScatter).
137 // *) Operations that are not thread safe (like infeed and rng).
138 // *) Tuple-shaped.
139 // *) Operations that might be implemented as an in-place
140 // dynamic-update-slice, because we can't know how many output elements
141 // they will write (out-of-place will touch the whole output buffer, while
142 // in-place will only touch the updated elements).
143 // TODO(b/27458679) Parallelize instructions which are skipped here.
144 auto opcode = instruction->opcode();
145 if (llvm_ir::MayBeImplementedAsInPlaceDynamicUpdateSlice(instruction) ||
146 instruction->shape().IsTuple() || opcode == HloOpcode::kRng ||
147 opcode == HloOpcode::kConstant) {
148 return 1;
149 }
150
151 // Only allow known good instructions.
152 if (instruction->IsElementwise() || instruction->IsLoopFusion() ||
153 opcode == HloOpcode::kBroadcast || opcode == HloOpcode::kConcatenate ||
154 opcode == HloOpcode::kDynamicSlice ||
155 opcode == HloOpcode::kDynamicUpdateSlice ||
156 opcode == HloOpcode::kGather || opcode == HloOpcode::kIota ||
157 opcode == HloOpcode::kPad || opcode == HloOpcode::kReduce ||
158 opcode == HloOpcode::kReduceWindow || opcode == HloOpcode::kReshape ||
159 opcode == HloOpcode::kReverse || opcode == HloOpcode::kSlice ||
160 opcode == HloOpcode::kTranspose ||
161 (opcode == HloOpcode::kConvolution &&
162 !PotentiallyImplementedAsEigenConvolution(*instruction,
163 target_machine_features_))) {
164 // Consult 'cost_model_' to compute target parallel task count.
165 return cost_model_->GetParallelTaskCount(instruction);
166 }
167
168 return 1;
169 }
170
Run(HloModule * module)171 StatusOr<bool> ParallelTaskAssigner::Run(HloModule* module) {
172 XLA_VLOG_LINES(2, "ParallelTaskAssigner ENTRY");
173 XLA_VLOG_LINES(3, module->ToString());
174 // Compute target parallel task counts for all instructions in 'module'.
175 HloToParallelTasks hlo_to_parallel_tasks;
176 ComputeTargetParallelTasks(module, &hlo_to_parallel_tasks);
177
178 // Assign parallel tasks to target specific instructions in 'module'.
179 // TODO(b/27458679) Support inter-op parallelism.
180 bool changed = AssignParallelTasks(module, hlo_to_parallel_tasks);
181
182 XLA_VLOG_LINES(2, "ParallelTaskAssigner EXIT");
183 XLA_VLOG_LINES(3, module->ToString());
184 return changed;
185 }
186
AssignParallelTasks(HloModule * module,const HloToParallelTasks & hlo_to_parallel_tasks)187 bool ParallelTaskAssigner::AssignParallelTasks(
188 HloModule* module, const HloToParallelTasks& hlo_to_parallel_tasks) {
189 return AssignParallelTasksHelper(module, module->entry_computation(),
190 hlo_to_parallel_tasks);
191 }
192
AssignParallelTasksHelper(HloModule * module,HloComputation * computation,const HloToParallelTasks & hlo_to_parallel_tasks)193 bool ParallelTaskAssigner::AssignParallelTasksHelper(
194 HloModule* module, HloComputation* computation,
195 const HloToParallelTasks& hlo_to_parallel_tasks) {
196 bool changed = false;
197 // Snapshot set of instructions because outlining modifies the set below.
198 std::vector<HloInstruction*> instructions(computation->instructions().begin(),
199 computation->instructions().end());
200 for (auto* instruction : instructions) {
201 // Assign parallel tasks to sub-computations for While and Call HLOs.
202 // TODO(b/27458679) Evaluate alternative intra-op parallelism placement,
203 // and support other callable computations like reduce.
204 if (instruction->opcode() == HloOpcode::kWhile) {
205 changed |= AssignParallelTasksHelper(module, instruction->while_body(),
206 hlo_to_parallel_tasks);
207 continue;
208 } else if (instruction->opcode() == HloOpcode::kCall) {
209 changed |= AssignParallelTasksHelper(module, instruction->to_apply(),
210 hlo_to_parallel_tasks);
211 continue;
212 }
213 // Skip if no parallel tasks were computed in first pass.
214 auto it = hlo_to_parallel_tasks.find(instruction);
215 if (it == hlo_to_parallel_tasks.end()) {
216 continue;
217 }
218 // Get target parallel task count computed for 'instruction'.
219 const int64 target_parallel_task_count = (*it).second;
220 // Assign feasible dimension partitions (based on actual dimension sizes).
221 auto dim_partition_counts = ShapePartitionAssigner(instruction->shape())
222 .Run(target_parallel_task_count);
223 const int64 total_partition_count =
224 ShapePartitionAssigner::GetTotalPartitionCount(dim_partition_counts);
225 if (total_partition_count <= 1) {
226 // Feasible partition calculation resulting in no partitioning, so skip.
227 continue;
228 }
229
230 // Outline 'instruction' in 'computation' for parallel task assignment.
231 auto* call = module->OutlineExpressionFromComputation(
232 {instruction}, absl::StrCat("parallel_", instruction->name()),
233 computation);
234
235 // Set assigned dimension partitioning to 'instruction'.
236 auto* new_root = call->to_apply()->root_instruction();
237 new_root->set_outer_dimension_partitions(dim_partition_counts);
238
239 VLOG(2) << "Assigned parallel task count: " << total_partition_count
240 << " to instruction: " << new_root->name()
241 << " parent: " << new_root->parent()->name();
242 changed = true;
243 }
244 return changed;
245 }
246
ComputeTargetParallelTasks(HloModule * module,HloToParallelTasks * hlo_to_parallel_tasks)247 void ParallelTaskAssigner::ComputeTargetParallelTasks(
248 HloModule* module, HloToParallelTasks* hlo_to_parallel_tasks) {
249 ParallelTaskAssignment parallel_task_assignment(max_parallelism_,
250 shape_size_function_, module,
251 &target_machine_features_);
252
253 // Compute parallel task counts for all instructions in 'module'.
254 for (auto* computation : module->MakeNonfusionComputations()) {
255 for (auto* instruction : computation->instructions()) {
256 // Query ParallelTaskAssignment for target parallel task count.
257 const int64 target_parallel_task_count =
258 parallel_task_assignment.GetTargetParallelTaskCount(instruction);
259 if (target_parallel_task_count > 1) {
260 hlo_to_parallel_tasks->insert(
261 {instruction, target_parallel_task_count});
262 }
263 }
264 }
265 }
266
267 } // namespace cpu
268 } // namespace xla
269