/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h"

#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h"
#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/cpu/shape_partition.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h"

namespace xla {
namespace cpu {

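// Fallback cost model: derives the task count purely from the instruction's
// output shape size relative to a typical L2 cache size.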
class SimpleCostModel : public ParallelCostModel {
 public:
  SimpleCostModel(const int64 max_parallelism,
                  const HloCostAnalysis::ShapeSizeFunction& shape_size)
      : max_parallelism_(max_parallelism), shape_size_(shape_size) {}
  ~SimpleCostModel() override {}

  int64 GetParallelTaskCount(HloInstruction* instruction) override {
    // Simple cost model based on hlo size and typical L2 cache size.
    const int64 instruction_cost = shape_size_(instruction->shape());
    const int64 min_cost_per_thread = 256LL << 10;  // 256KB L2 Cache size.
    // Return target parallel task count in [1, max_parallelism_].
    return std::min(max_parallelism_,
                    std::max(int64{1}, instruction_cost / min_cost_per_thread));
  }

 private:
  const int64 max_parallelism_;
  const HloCostAnalysis::ShapeSizeFunction shape_size_;
};

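// Cost model backed by HloCostAnalysis: distinguishes memory-bound
// instructions (flops-to-bytes ratio <= 1.0) from compute-bound ones and
// sizes the per-thread work accordingly.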
class DefaultCostModel : public ParallelCostModel {
 public:
  DefaultCostModel(const int64 max_parallelism,
                   const HloCostAnalysis::ShapeSizeFunction& shape_size,
                   std::unique_ptr<HloCostAnalysis> cost_analysis)
      : max_parallelism_(max_parallelism),
        shape_size_(shape_size),
        cost_analysis_(std::move(cost_analysis)) {}
  ~DefaultCostModel() override {}

  int64 GetParallelTaskCount(HloInstruction* instruction) override {
    // Parameters for parallel task count computation.
    int64 instruction_cost;
    int64 min_cost_per_thread;
    int64 max_parallelism;
    // Calculate flops-to-bytes-ratio for 'instruction'.
    const int64 bytes_accessed =
        std::max(int64{1}, cost_analysis_->bytes_accessed(*instruction));
    const float flops_to_bytes_ratio =
        cost_analysis_->flop_count(*instruction) /
        static_cast<float>(bytes_accessed);
    // Check for I/O bound instructions.
    if (flops_to_bytes_ratio <= 1.0) {
      // Limit max parallelism for I/O bound instructions by assuming a
      // sub-linear scaling function (fit based on empirical benchmark results).
      // TODO(b/29630486) Develop system bandwidth model.
      max_parallelism = std::min<int64>(
          max_parallelism_,
          std::ceil(std::sqrt(tensorflow::port::MaxParallelism())));
      // Use shape size instruction cost and L2 cache size min per-thread cost.
      instruction_cost = shape_size_(instruction->shape());
      min_cost_per_thread = 256LL << 10;  // 256KB L2 Cache size.
    } else {
      // Use max parallelism for compute bound instructions.
      max_parallelism = max_parallelism_;
      // Calculate the instruction cost in cycles.
      // TODO(b/29630486) Improve on this linear cost model.
      // Consider making 'min_cost_per_thread' be a function of the target
      // bandwidth limit for instructions with low arithmetic complexity.
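      // Weighted sum of operation counts: bytes accessed are weighted most
      // heavily, and a transcendental counts as twice a plain flop.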
      instruction_cost =
          1 * cost_analysis_->flop_count(*instruction) +
          2 * cost_analysis_->transcendental_count(*instruction) +
          10 * cost_analysis_->bytes_accessed(*instruction);
      // Minimum per-thread cost is 100,000 cycles (~50us of work on a 2GHz
      // core).
      min_cost_per_thread = 100000;
    }
    // Return target parallel task count in [1, max_parallelism_].
    return std::min(max_parallelism,
                    std::max(int64{1}, instruction_cost / min_cost_per_thread));
  }

 private:
  const int64 max_parallelism_;
  const HloCostAnalysis::ShapeSizeFunction shape_size_;
  const std::unique_ptr<HloCostAnalysis> cost_analysis_;
};

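// Chooses the cost model: DefaultCostModel when HloCostAnalysis succeeds on
// the entry computation, SimpleCostModel otherwise.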
ParallelTaskAssignment::ParallelTaskAssignment(
    const int64 max_parallelism,
    const HloCostAnalysis::ShapeSizeFunction& shape_size, HloModule* module,
    const TargetMachineFeatures* target_machine_features)
    : target_machine_features_(*target_machine_features) {
  VLOG(1) << "ParallelTaskAssignment max_parallelism: " << max_parallelism;
  // Run cost analysis on 'module'.
  auto cost_analysis = absl::make_unique<HloCostAnalysis>(shape_size);
  HloComputation* computation = module->entry_computation();
  Status status = computation->root_instruction()->Accept(cost_analysis.get());
  if (status.ok()) {
    // Set default cost model based on 'cost_analysis'.
    cost_model_.reset(new DefaultCostModel(max_parallelism, shape_size,
                                           std::move(cost_analysis)));
  } else {
    // Fall back to a simple cost model based on hlo size and L2 cache size.
    // Note that HloCostAnalysis can return an error status (likely because
    // HLOs like CustomCall are not yet implemented in HloCostAnalysis).
    cost_model_.reset(new SimpleCostModel(max_parallelism, shape_size));
  }
}

int64 ParallelTaskAssignment::GetTargetParallelTaskCount(
    HloInstruction* instruction) {
  // Currently, we do not assign parallel tasks to instructions with at least
  // one of the following properties:
  // *) Internal threading (library calls to kConv, kDot, kFft, kCustomCall).
  // *) Emit custom loops (kSelectAndScatter).
  // *) Operations that are not thread safe (like infeed and rng).
  // *) Tuple-shaped.
  // *) Operations that might be implemented as an in-place
  //    dynamic-update-slice, because we can't know how many output elements
  //    they will write (out-of-place will touch the whole output buffer, while
  //    in-place will only touch the updated elements).
  // TODO(b/27458679) Parallelize instructions which are skipped here.
  auto opcode = instruction->opcode();
  if (llvm_ir::MayBeImplementedAsInPlaceDynamicUpdateSlice(instruction) ||
      instruction->shape().IsTuple() || opcode == HloOpcode::kRng ||
      opcode == HloOpcode::kConstant) {
    return 1;
  }

  // Only allow known good instructions.
  if (instruction->IsElementwise() || instruction->IsLoopFusion() ||
      opcode == HloOpcode::kBroadcast || opcode == HloOpcode::kConcatenate ||
      opcode == HloOpcode::kDynamicSlice ||
      opcode == HloOpcode::kDynamicUpdateSlice ||
      opcode == HloOpcode::kGather || opcode == HloOpcode::kIota ||
      opcode == HloOpcode::kPad || opcode == HloOpcode::kReduce ||
      opcode == HloOpcode::kReduceWindow || opcode == HloOpcode::kReshape ||
      opcode == HloOpcode::kReverse || opcode == HloOpcode::kSlice ||
      opcode == HloOpcode::kTranspose ||
      (opcode == HloOpcode::kConvolution &&
       !PotentiallyImplementedAsEigenConvolution(*instruction,
                                                 target_machine_features_))) {
    // Consult 'cost_model_' to compute target parallel task count.
    return cost_model_->GetParallelTaskCount(instruction);
  }

  return 1;
}

StatusOr<bool> ParallelTaskAssigner::Run(HloModule* module) {
  XLA_VLOG_LINES(2, "ParallelTaskAssigner ENTRY");
  XLA_VLOG_LINES(3, module->ToString());
  // Compute target parallel task counts for all instructions in 'module'.
  HloToParallelTasks hlo_to_parallel_tasks;
  ComputeTargetParallelTasks(module, &hlo_to_parallel_tasks);

  // Assign parallel tasks to target specific instructions in 'module'.
  // TODO(b/27458679) Support inter-op parallelism.
  bool changed = AssignParallelTasks(module, hlo_to_parallel_tasks);

  XLA_VLOG_LINES(2, "ParallelTaskAssigner EXIT");
  XLA_VLOG_LINES(3, module->ToString());
  return changed;
}

bool ParallelTaskAssigner::AssignParallelTasks(
    HloModule* module, const HloToParallelTasks& hlo_to_parallel_tasks) {
  return AssignParallelTasksHelper(module, module->entry_computation(),
                                   hlo_to_parallel_tasks);
}

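// Recursively walks 'computation' (descending into while bodies and called
// computations), outlining each instruction that received a partition
// assignment into its own parallel call.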
bool ParallelTaskAssigner::AssignParallelTasksHelper(
    HloModule* module, HloComputation* computation,
    const HloToParallelTasks& hlo_to_parallel_tasks) {
  bool changed = false;
  // Snapshot set of instructions because outlining modifies the set below.
  std::vector<HloInstruction*> instructions(computation->instructions().begin(),
                                            computation->instructions().end());
  for (auto* instruction : instructions) {
    // Assign parallel tasks to sub-computations for While and Call HLOs.
    // TODO(b/27458679) Evaluate alternative intra-op parallelism placement,
    // and support other callable computations like reduce.
    if (instruction->opcode() == HloOpcode::kWhile) {
      changed |= AssignParallelTasksHelper(module, instruction->while_body(),
                                           hlo_to_parallel_tasks);
      continue;
    } else if (instruction->opcode() == HloOpcode::kCall) {
      changed |= AssignParallelTasksHelper(module, instruction->to_apply(),
                                           hlo_to_parallel_tasks);
      continue;
    }
    // Skip if no parallel tasks were computed in the first pass.
    auto it = hlo_to_parallel_tasks.find(instruction);
    if (it == hlo_to_parallel_tasks.end()) {
      continue;
    }
    // Get target parallel task count computed for 'instruction'.
    const int64 target_parallel_task_count = (*it).second;
    // Assign feasible dimension partitions (based on actual dimension sizes).
    auto dim_partition_counts = ShapePartitionAssigner(instruction->shape())
                                    .Run(target_parallel_task_count);
    const int64 total_partition_count =
        ShapePartitionAssigner::GetTotalPartitionCount(dim_partition_counts);
    if (total_partition_count <= 1) {
      // Feasible partition calculation resulted in no partitioning, so skip.
      continue;
    }

    // Outline 'instruction' in 'computation' for parallel task assignment.
    auto* call = module->OutlineExpressionFromComputation(
        {instruction}, absl::StrCat("parallel_", instruction->name()),
        computation);

    // Set assigned dimension partitioning to 'instruction'.
    auto* new_root = call->to_apply()->root_instruction();
    new_root->set_outer_dimension_partitions(dim_partition_counts);

    VLOG(2) << "Assigned parallel task count: " << total_partition_count
            << " to instruction: " << new_root->name()
            << " parent: " << new_root->parent()->name();
    changed = true;
  }
  return changed;
}

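// Builds the instruction -> task-count map for every non-fusion computation
// in 'module', recording only instructions whose target task count exceeds
// one.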
void ParallelTaskAssigner::ComputeTargetParallelTasks(
    HloModule* module, HloToParallelTasks* hlo_to_parallel_tasks) {
  ParallelTaskAssignment parallel_task_assignment(max_parallelism_,
                                                  shape_size_function_, module,
                                                  &target_machine_features_);

  // Compute parallel task counts for all instructions in 'module'.
  for (auto* computation : module->MakeNonfusionComputations()) {
    for (auto* instruction : computation->instructions()) {
      // Query ParallelTaskAssignment for target parallel task count.
      const int64 target_parallel_task_count =
          parallel_task_assignment.GetTargetParallelTaskCount(instruction);
      if (target_parallel_task_count > 1) {
        hlo_to_parallel_tasks->insert(
            {instruction, target_parallel_task_count});
      }
    }
  }
}

}  // namespace cpu
}  // namespace xla