/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h"

#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h"
#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/cpu/shape_partition.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"

namespace xla {
namespace cpu {

class SimpleCostModel : public ParallelCostModel {
 public:
  SimpleCostModel(const int64 max_parallelism,
                  const HloCostAnalysis::ShapeSizeFunction& shape_size)
      : max_parallelism_(max_parallelism), shape_size_(shape_size) {}
  ~SimpleCostModel() override {}

  int64 GetParallelTaskCount(HloInstruction* instruction) override {
    // Simple cost model based on hlo size and typical L2 cache size.
    const int64 instruction_cost = shape_size_(instruction->shape());
    const int64 min_cost_per_thread = 256LL << 10;  // 256KB L2 Cache size.
    // Return target parallel task count in [1, max_parallelism_].
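    // For example, for a hypothetical f32[1024,1024] output (4 MiB), this
    // gives 4 MiB / 256 KiB = 16 tasks, subject to the max_parallelism_ cap.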
    return std::min(max_parallelism_,
                    std::max(int64{1}, instruction_cost / min_cost_per_thread));
  }

 private:
  const int64 max_parallelism_;
  const HloCostAnalysis::ShapeSizeFunction shape_size_;
};

class DefaultCostModel : public ParallelCostModel {
 public:
  DefaultCostModel(const int64 max_parallelism,
                   const HloCostAnalysis::ShapeSizeFunction& shape_size,
                   std::unique_ptr<HloCostAnalysis> cost_analysis)
      : max_parallelism_(max_parallelism),
        shape_size_(shape_size),
        cost_analysis_(std::move(cost_analysis)) {}
  ~DefaultCostModel() override {}

  int64 GetParallelTaskCount(HloInstruction* instruction) override {
    // Parameters for parallel task count computation.
    int64 instruction_cost;
    int64 min_cost_per_thread;
    int64 max_parallelism;
    // Calculate flops-to-bytes-ratio for 'instruction'.
    const int64 bytes_accessed =
        std::max(int64{1}, cost_analysis_->bytes_accessed(*instruction));
    const float flops_to_bytes_ratio =
        cost_analysis_->flop_count(*instruction) /
        static_cast<float>(bytes_accessed);
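    // For example, an elementwise f32 add does roughly one flop per element
    // while reading two operands and writing one result (~12 bytes per
    // element), so its ratio falls well below 1.0 (I/O bound); an instruction
    // that reuses each byte for many flops takes the compute-bound branch.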
    // Check for I/O bound instructions.
    if (flops_to_bytes_ratio <= 1.0) {
      // Limit max parallelism for I/O bound instructions by assuming a
      // sub-linear scaling function (fit based on empirical benchmark results).
      // TODO(b/29630486) Develop system bandwidth model.
      max_parallelism =
          std::ceil(std::sqrt(tensorflow::port::NumSchedulableCPUs()));
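      // For example, on a machine with 16 schedulable CPUs this caps
      // I/O-bound instructions at ceil(sqrt(16)) = 4 parallel tasks.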
      // Use shape size instruction cost and L2 cache size min per-thread cost.
      instruction_cost = shape_size_(instruction->shape());
      min_cost_per_thread = 256LL << 10;  // 256KB L2 Cache size.
    } else {
      // Use max parallelism for compute bound instructions.
      max_parallelism = max_parallelism_;
      // Calculate the instruction cost in cycles.
      // TODO(b/29630486) Improve on this linear cost model.
      // Consider making 'min_cost_per_thread' be a function of the target
      // bandwidth limit for instructions with low arithmetic complexity.
      instruction_cost =
          1 * cost_analysis_->flop_count(*instruction) +
          2 * cost_analysis_->transcendental_count(*instruction) +
          10 * cost_analysis_->bytes_accessed(*instruction);
      // Minimum per-thread cost is 100,000 cycles (about 50us of work on a
      // 2GHz core).
      min_cost_per_thread = 100000;
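      // For example, a hypothetical compute-bound instruction with 1e8 flops,
      // no transcendentals, and 1 MB of bytes accessed costs roughly
      // 1e8 + 10 * 1e6 = 1.1e8 cycles, i.e. about 1100 tasks before capping.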
    }
    // Return target parallel task count in [1, max_parallelism].
    return std::min(max_parallelism,
                    std::max(int64{1}, instruction_cost / min_cost_per_thread));
  }

 private:
  const int64 max_parallelism_;
  const HloCostAnalysis::ShapeSizeFunction shape_size_;
  const std::unique_ptr<HloCostAnalysis> cost_analysis_;
};

ParallelTaskAssignment::ParallelTaskAssignment(
    const int64 max_parallelism,
    const HloCostAnalysis::ShapeSizeFunction& shape_size, HloModule* module,
    const TargetMachineFeatures* target_machine_features)
    : target_machine_features_(*target_machine_features) {
  VLOG(1) << "ParallelTaskAssignment max_parallelism: " << max_parallelism;
  // Run cost analysis on 'module'.
  auto cost_analysis = absl::make_unique<HloCostAnalysis>(shape_size);
  HloComputation* computation = module->entry_computation();
  Status status = computation->root_instruction()->Accept(cost_analysis.get());
  if (status.ok()) {
    // Set default cost model based on 'cost_analysis'.
    cost_model_.reset(new DefaultCostModel(max_parallelism, shape_size,
                                           std::move(cost_analysis)));
  } else {
    // Fall back to a simple cost model based on hlo size and L2 cache size.
    // Note that HloCostAnalysis can return an error status (likely because
    // HLOs like CustomCall are not yet implemented in HloCostAnalysis).
    cost_model_.reset(new SimpleCostModel(max_parallelism, shape_size));
  }
}

int64 ParallelTaskAssignment::GetTargetParallelTaskCount(
    HloInstruction* instruction) {
  // Currently, we do not assign parallel tasks to instructions with at least
  // one of the following properties:
  // *) Internal threading (library calls to kConv, kDot, kFft, kCustomCall).
  // *) Custom loop emission (kSelectAndScatter).
  // *) Operations that are not thread safe (like infeed and rng).
  // *) Tuple-shaped.
  // TODO(b/27458679) Parallelize instructions which are skipped here.
  auto opcode = instruction->opcode();
  if (opcode == HloOpcode::kParameter || opcode == HloOpcode::kConstant ||
      opcode == HloOpcode::kCall || opcode == HloOpcode::kCustomCall ||
      opcode == HloOpcode::kDot || opcode == HloOpcode::kSelectAndScatter ||
      opcode == HloOpcode::kGetTupleElement || opcode == HloOpcode::kBitcast ||
      opcode == HloOpcode::kFft || opcode == HloOpcode::kInfeed ||
      opcode == HloOpcode::kOutfeed || opcode == HloOpcode::kRng ||
      opcode == HloOpcode::kSort ||
      (opcode == HloOpcode::kConvolution &&
       PotentiallyImplementedAsEigenConvolution(*instruction,
                                                target_machine_features_)) ||
      (opcode == HloOpcode::kFusion &&
       instruction->fusion_kind() != HloInstruction::FusionKind::kLoop) ||
      instruction->shape().IsTuple()) {
    return 1;
  }

  // Consult 'cost_model_' to compute target parallel task count.
  return cost_model_->GetParallelTaskCount(instruction);
}

StatusOr<bool> ParallelTaskAssigner::Run(HloModule* module) {
  XLA_VLOG_LINES(2, "ParallelTaskAssigner ENTRY");
  XLA_VLOG_LINES(3, module->ToString());
  // Compute target parallel task counts for all instructions in 'module'.
  HloToParallelTasks hlo_to_parallel_tasks;
  ComputeTargetParallelTasks(module, &hlo_to_parallel_tasks);

  // Assign parallel tasks to target specific instructions in 'module'.
  // TODO(b/27458679) Support inter-op parallelism.
  bool changed = AssignParallelTasks(module, hlo_to_parallel_tasks);

  XLA_VLOG_LINES(2, "ParallelTaskAssigner EXIT");
  XLA_VLOG_LINES(3, module->ToString());
  return changed;
}

bool ParallelTaskAssigner::AssignParallelTasks(
    HloModule* module, const HloToParallelTasks& hlo_to_parallel_tasks) {
  return AssignParallelTasksHelper(module, module->entry_computation(),
                                   hlo_to_parallel_tasks);
}

bool ParallelTaskAssigner::AssignParallelTasksHelper(
    HloModule* module, HloComputation* computation,
    const HloToParallelTasks& hlo_to_parallel_tasks) {
  bool changed = false;
  // Snapshot set of instructions because outlining modifies the set below.
  std::vector<HloInstruction*> instructions(computation->instructions().begin(),
                                            computation->instructions().end());
  for (auto* instruction : instructions) {
    // TODO(b/27458679) Evaluate alternative intra-op parallelism placement,
    // and support other callable computations like reduce.
    if (instruction->opcode() == HloOpcode::kWhile) {
      changed |= AssignParallelTasksHelper(module, instruction->while_body(),
                                           hlo_to_parallel_tasks);
      continue;
    } else if (instruction->opcode() == HloOpcode::kCall) {
      changed |= AssignParallelTasksHelper(module, instruction->to_apply(),
                                           hlo_to_parallel_tasks);
      continue;
    }
    // Skip if no parallel tasks were computed in first pass.
    auto it = hlo_to_parallel_tasks.find(instruction);
    if (it == hlo_to_parallel_tasks.end()) {
      continue;
    }
    // Get target parallel task count computed for 'instruction'.
    const int64 target_parallel_task_count = (*it).second;
    // Assign feasible dimension partitions (based on actual dimension sizes).
    auto dim_partition_counts = ShapePartitionAssigner(instruction->shape())
                                    .Run(target_parallel_task_count);
    const int64 total_partition_count =
        ShapePartitionAssigner::GetTotalPartitionCount(dim_partition_counts);
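    // For example, hypothetical dimension partition counts of {2, 4} over the
    // two outermost dimensions would give 2 * 4 = 8 total partitions.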
    if (total_partition_count <= 1) {
      // The feasible partition calculation resulted in no partitioning, so
      // skip this instruction.
      continue;
    }

    // Outline 'instruction' in 'computation' for parallel task assignment.
    auto* call = module->OutlineExpressionFromComputation(
        {instruction}, absl::StrCat("parallel_", instruction->name()),
        computation);

    // Set assigned dimension partitioning to 'instruction'.
    auto* new_root = call->to_apply()->root_instruction();
    new_root->set_outer_dimension_partitions(dim_partition_counts);

    VLOG(2) << "Assigned parallel task count: " << total_partition_count
            << " to instruction: " << new_root->name()
            << " parent: " << new_root->parent()->name();
    changed = true;
  }
  return changed;
}

void ParallelTaskAssigner::ComputeTargetParallelTasks(
    HloModule* module, HloToParallelTasks* hlo_to_parallel_tasks) {
  ParallelTaskAssignment parallel_task_assignment(max_parallelism_,
                                                  shape_size_function_, module,
                                                  &target_machine_features_);

  // Compute parallel task counts for all instructions in 'module'.
  for (auto* computation : module->computations()) {
    if (computation->IsFusionComputation()) {
      continue;
    }
    for (auto* instruction : computation->instructions()) {
      // Query ParallelTaskAssignment for target parallel task count.
      const int64 target_parallel_task_count =
          parallel_task_assignment.GetTargetParallelTaskCount(instruction);
      if (target_parallel_task_count > 1) {
        hlo_to_parallel_tasks->insert(
            {instruction, target_parallel_task_count});
      }
    }
  }
}

}  // namespace cpu
}  // namespace xla