• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h"
17 
18 #include "absl/memory/memory.h"
19 #include "absl/strings/str_cat.h"
20 #include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h"
21 #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
22 #include "tensorflow/compiler/xla/service/cpu/shape_partition.h"
23 #include "tensorflow/compiler/xla/service/hlo_computation.h"
24 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
25 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
26 
27 namespace xla {
28 namespace cpu {
29 
30 class SimpleCostModel : public ParallelCostModel {
31  public:
SimpleCostModel(const int64 max_parallelism,const HloCostAnalysis::ShapeSizeFunction & shape_size)32   SimpleCostModel(const int64 max_parallelism,
33                   const HloCostAnalysis::ShapeSizeFunction& shape_size)
34       : max_parallelism_(max_parallelism), shape_size_(shape_size) {}
~SimpleCostModel()35   ~SimpleCostModel() override {}
36 
GetParallelTaskCount(HloInstruction * instruction)37   int64 GetParallelTaskCount(HloInstruction* instruction) override {
38     // Simple cost model based on hlo size and typical L2 cache size.
39     const int64 instruction_cost = shape_size_(instruction->shape());
40     const int64 min_cost_per_thread = 256LL << 10;  // 256KB L2 Cache size.
41     // Return target parallel task count in [1, max_parallelism_].
42     return std::min(max_parallelism_,
43                     std::max(int64{1}, instruction_cost / min_cost_per_thread));
44   }
45 
46  private:
47   const int64 max_parallelism_;
48   const HloCostAnalysis::ShapeSizeFunction shape_size_;
49 };
50 
51 class DefaultCostModel : public ParallelCostModel {
52  public:
DefaultCostModel(const int64 max_parallelism,const HloCostAnalysis::ShapeSizeFunction & shape_size,std::unique_ptr<HloCostAnalysis> cost_analysis)53   DefaultCostModel(const int64 max_parallelism,
54                    const HloCostAnalysis::ShapeSizeFunction& shape_size,
55                    std::unique_ptr<HloCostAnalysis> cost_analysis)
56       : max_parallelism_(max_parallelism),
57         shape_size_(shape_size),
58         cost_analysis_(std::move(cost_analysis)) {}
~DefaultCostModel()59   ~DefaultCostModel() override {}
60 
GetParallelTaskCount(HloInstruction * instruction)61   int64 GetParallelTaskCount(HloInstruction* instruction) override {
62     // Parameters for parallel task count computation.
63     int64 instruction_cost;
64     int64 min_cost_per_thread;
65     int64 max_parallelism;
66     // Calculate flops-to-bytes-ratio for 'instruction'.
67     const int64 bytes_accessed =
68         std::max(int64{1}, cost_analysis_->bytes_accessed(*instruction));
69     const float flops_to_bytes_ratio =
70         cost_analysis_->flop_count(*instruction) /
71         static_cast<float>(bytes_accessed);
72     // Check for I/O bound instructions.
73     if (flops_to_bytes_ratio <= 1.0) {
74       // Limit max parallelism for I/O bound instructions by assuming a
75       // sub-linear scaling function (fit based on empirical benchmark results).
76       // TODO(b/29630486) Develop system bandwidth model.
77       max_parallelism =
78           std::ceil(std::sqrt(tensorflow::port::NumSchedulableCPUs()));
79       // Use shape size instruction cost and L2 cache size min per-thread cost.
80       instruction_cost = shape_size_(instruction->shape());
81       min_cost_per_thread = 256LL << 10;  // 256KB L2 Cache size.
82     } else {
83       // Use max parallelism for compute bound instructions.
84       max_parallelism = max_parallelism_;
85       // Calculate the instruction cost in cycles.
86       // TODO(b/29630486) Improve on this linear cost model.
87       // Consider making 'min_cost_per_thread' be a function of the target
88       // bandwidth limit for instructions with low arithmetic complexity.
89       instruction_cost =
90           1 * cost_analysis_->flop_count(*instruction) +
91           2 * cost_analysis_->transcendental_count(*instruction) +
92           10 * cost_analysis_->bytes_accessed(*instruction);
93       // Minimum per-thread cost is 100us of work on a 2GHz core.
94       min_cost_per_thread = 100000;
95     }
96     // Return target parallel task count in [1, max_parallelism_].
97     return std::min(max_parallelism,
98                     std::max(int64{1}, instruction_cost / min_cost_per_thread));
99   }
100 
101  private:
102   const int64 max_parallelism_;
103   const HloCostAnalysis::ShapeSizeFunction shape_size_;
104   const std::unique_ptr<HloCostAnalysis> cost_analysis_;
105 };
106 
ParallelTaskAssignment(const int64 max_parallelism,const HloCostAnalysis::ShapeSizeFunction & shape_size,HloModule * module,const TargetMachineFeatures * target_machine_features)107 ParallelTaskAssignment::ParallelTaskAssignment(
108     const int64 max_parallelism,
109     const HloCostAnalysis::ShapeSizeFunction& shape_size, HloModule* module,
110     const TargetMachineFeatures* target_machine_features)
111     : target_machine_features_(*target_machine_features) {
112   VLOG(1) << "ParallelTaskAssignment max_parallelism: " << max_parallelism;
113   // Run cost analysis on 'module'.
114   auto cost_analysis = absl::make_unique<HloCostAnalysis>(shape_size);
115   HloComputation* computation = module->entry_computation();
116   Status status = computation->root_instruction()->Accept(cost_analysis.get());
117   if (status.ok()) {
118     // Set default cost model based on 'cost_analysis'.
119     cost_model_.reset(new DefaultCostModel(max_parallelism, shape_size,
120                                            std::move(cost_analysis)));
121   } else {
122     // Fall back to a simple cost model based on hlo size and L2 cache size.
123     // Note that HloCostAnalysis can returns an error status (likely because
124     // HLOs like CustomCall are not yet implemented in the HloCostAnalysis).
125     cost_model_.reset(new SimpleCostModel(max_parallelism, shape_size));
126   }
127 }
128 
GetTargetParallelTaskCount(HloInstruction * instruction)129 int64 ParallelTaskAssignment::GetTargetParallelTaskCount(
130     HloInstruction* instruction) {
131   // Currently, we do not assign parallel tasks to instructions with at least
132   // one of the following properties:
133   // *) Internal threading (library calls to kConv, kDot, kFft, kCustomCall).
134   // *) Emit custom loops (kSelectAndScatter).
135   // *) Operations that are not thread safe (like infeed and rng).
136   // *) Tuple-shaped.
137   // TODO(b/27458679) Parallelize instructions which are skipped here.
138   auto opcode = instruction->opcode();
139   if (opcode == HloOpcode::kParameter || opcode == HloOpcode::kConstant ||
140       opcode == HloOpcode::kCall || opcode == HloOpcode::kCustomCall ||
141       opcode == HloOpcode::kDot || opcode == HloOpcode::kSelectAndScatter ||
142       opcode == HloOpcode::kGetTupleElement || opcode == HloOpcode::kBitcast ||
143       opcode == HloOpcode::kFft || opcode == HloOpcode::kInfeed ||
144       opcode == HloOpcode::kOutfeed || opcode == HloOpcode::kRng ||
145       opcode == HloOpcode::kSort ||
146       (opcode == HloOpcode::kConvolution &&
147        PotentiallyImplementedAsEigenConvolution(*instruction,
148                                                 target_machine_features_)) ||
149       (opcode == HloOpcode::kFusion &&
150        instruction->fusion_kind() != HloInstruction::FusionKind::kLoop) ||
151       instruction->shape().IsTuple()) {
152     return 1;
153   }
154 
155   // Consult 'cost_model_' to compute target parallel task count.
156   return cost_model_->GetParallelTaskCount(instruction);
157 }
158 
Run(HloModule * module)159 StatusOr<bool> ParallelTaskAssigner::Run(HloModule* module) {
160   XLA_VLOG_LINES(2, "ParallelTaskAssigner ENTRY");
161   XLA_VLOG_LINES(3, module->ToString());
162   // Compute target parallel task counts for all instructions in 'module'.
163   HloToParallelTasks hlo_to_parallel_tasks;
164   ComputeTargetParallelTasks(module, &hlo_to_parallel_tasks);
165 
166   // Assign parallel tasks to target specific instructions in 'module'.
167   // TODO(b/27458679) Support inter-op parallelism.
168   bool changed = AssignParallelTasks(module, hlo_to_parallel_tasks);
169 
170   XLA_VLOG_LINES(2, "ParallelTaskAssigner EXIT");
171   XLA_VLOG_LINES(3, module->ToString());
172   return changed;
173 }
174 
AssignParallelTasks(HloModule * module,const HloToParallelTasks & hlo_to_parallel_tasks)175 bool ParallelTaskAssigner::AssignParallelTasks(
176     HloModule* module, const HloToParallelTasks& hlo_to_parallel_tasks) {
177   return AssignParallelTasksHelper(module, module->entry_computation(),
178                                    hlo_to_parallel_tasks);
179 }
180 
AssignParallelTasksHelper(HloModule * module,HloComputation * computation,const HloToParallelTasks & hlo_to_parallel_tasks)181 bool ParallelTaskAssigner::AssignParallelTasksHelper(
182     HloModule* module, HloComputation* computation,
183     const HloToParallelTasks& hlo_to_parallel_tasks) {
184   bool changed = false;
185   // Snapshot set of instructions because outlining modifies the set below.
186   std::vector<HloInstruction*> instructions(computation->instructions().begin(),
187                                             computation->instructions().end());
188   for (auto* instruction : instructions) {
189     // Assign parallel tasks to sub-computations for While and Call HLOs.
190     // TODO(b/27458679) Evaluate alternative intra-op parallelsim placement,
191     // and support other callable computations like reduce.
192     if (instruction->opcode() == HloOpcode::kWhile) {
193       changed |= AssignParallelTasksHelper(module, instruction->while_body(),
194                                            hlo_to_parallel_tasks);
195       continue;
196     } else if (instruction->opcode() == HloOpcode::kCall) {
197       changed |= AssignParallelTasksHelper(module, instruction->to_apply(),
198                                            hlo_to_parallel_tasks);
199       continue;
200     }
201     // Skip if no parallel tasks were computed in first pass.
202     auto it = hlo_to_parallel_tasks.find(instruction);
203     if (it == hlo_to_parallel_tasks.end()) {
204       continue;
205     }
206     // Get target parallel task count computed for 'instruction'.
207     const int64 target_parallel_task_count = (*it).second;
208     // Assign feasible dimension partitions (based on actual dimension sizes).
209     auto dim_partition_counts = ShapePartitionAssigner(instruction->shape())
210                                     .Run(target_parallel_task_count);
211     const int64 total_partition_count =
212         ShapePartitionAssigner::GetTotalPartitionCount(dim_partition_counts);
213     if (total_partition_count <= 1) {
214       // Feasible partition calculation resulting in no partitioning, so skip.
215       continue;
216     }
217 
218     // Outline 'instruction' in 'computation' for parallel task assignment.
219     auto* call = module->OutlineExpressionFromComputation(
220         {instruction}, absl::StrCat("parallel_", instruction->name()),
221         computation);
222 
223     // Set assigned dimension partitioning to 'instruction'.
224     auto* new_root = call->to_apply()->root_instruction();
225     new_root->set_outer_dimension_partitions(dim_partition_counts);
226 
227     VLOG(2) << "Assigned parallel task count: " << total_partition_count
228             << " to instruction: " << new_root->name()
229             << " parent: " << new_root->parent()->name();
230     changed = true;
231   }
232   return changed;
233 }
234 
ComputeTargetParallelTasks(HloModule * module,HloToParallelTasks * hlo_to_parallel_tasks)235 void ParallelTaskAssigner::ComputeTargetParallelTasks(
236     HloModule* module, HloToParallelTasks* hlo_to_parallel_tasks) {
237   ParallelTaskAssignment parallel_task_assignment(max_parallelism_,
238                                                   shape_size_function_, module,
239                                                   &target_machine_features_);
240 
241   // Compute parallel task counts for all instructions in 'module'.
242   for (auto* computation : module->computations()) {
243     if (computation->IsFusionComputation()) {
244       continue;
245     }
246     for (auto* instruction : computation->instructions()) {
247       // Query ParallelTaskAssignment for target parallel task count.
248       const int64 target_parallel_task_count =
249           parallel_task_assignment.GetTargetParallelTaskCount(instruction);
250       if (target_parallel_task_count > 1) {
251         hlo_to_parallel_tasks->insert(
252             {instruction, target_parallel_task_count});
253       }
254     }
255   }
256 }
257 
258 }  // namespace cpu
259 }  // namespace xla
260