• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/gpu/reduction_splitter.h"
17 
18 #include <algorithm>
19 
20 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
21 #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
22 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
23 #include "tensorflow/compiler/xla/shape_util.h"
24 
25 namespace xla {
26 namespace gpu {
27 
28 class ReductionSplitterVisitor : public DfsHloRewriteVisitor {
29  public:
HandleReduce(HloInstruction * reduce)30   Status HandleReduce(HloInstruction *reduce) override {
31     VLOG(4) << "Input: " << reduce->ToString();
32 
33     // Reductions with contiguous dimensions are lowered to efficient code. No
34     // need to split such ops.
35     if (IsReductionFromOrToContiguousDimensions(*reduce)) {
36       return Status::OK();
37     }
38     if (reduce->dimensions().size() < 2) {
39       return Status::OK();
40     }
41     if (!reduce->shape().IsArray()) {
42       // TODO(cheshire): Handle variadic reduction.
43       return Status::OK();
44     }
45 
46     HloInstruction *operand = reduce->mutable_operand(0);
47     const Shape &shape = operand->shape();
48     CHECK(shape == LayoutUtil::GetWithDefaultLayout(shape))
49         << "Default layout should be enforced on reduction operand";
50     // Verify that contiguous dimensions have been grouped by the
51     // ReductionDimensionGrouper pass.
52     for (int64 i = 0; i < reduce->dimensions().size(); ++i) {
53       for (int64 j = i + 1; j < reduce->dimensions().size(); ++j) {
54         CHECK(abs(reduce->dimensions(i) - reduce->dimensions(j)) > 1)
55             << "Reduction dimensions must not be consecutive";
56       }
57     }
58 
59     // The reduce op has non-contiguous dimensions. Look for the dimension with
60     // the largest shape dimension. Reducing along this dimension first will
61     // reduce the output size most effectively.
62     int64 max_shape_dim = 0;
63     int64 max_reduce_dim = 0;
64     const auto &input_shape = reduce->operand(0)->shape();
65     for (int64 i = 0; i < reduce->dimensions().size(); ++i) {
66       if (input_shape.dimensions(reduce->dimensions(i)) > max_shape_dim) {
67         max_reduce_dim = reduce->dimensions(i);
68         max_shape_dim = input_shape.dimensions(max_reduce_dim);
69       }
70     }
71     // TODO(tjoerg): Run microbenchmarks to tune this threshold.
72     if (max_shape_dim < 128) {
73       return Status::OK();
74     }
75 
76     // Split the reduction into a pre-reduction and a final reduction.
77     VLOG(3) << "Splitting reduction " << reduce->name() << " at dimension "
78             << max_reduce_dim;
79     std::vector<int64> pre_reduce_dims;
80     pre_reduce_dims.push_back(max_reduce_dim);
81     std::vector<int64> pre_reduce_shape_dims(input_shape.dimensions().begin(),
82                                              input_shape.dimensions().end());
83     pre_reduce_shape_dims.erase(pre_reduce_shape_dims.begin() + max_reduce_dim);
84     Shape pre_reduce_shape = ShapeUtil::MakeShape(
85         reduce->shape().element_type(), pre_reduce_shape_dims);
86     std::unique_ptr<HloInstruction> pre_reduce = HloInstruction::CreateReduce(
87         pre_reduce_shape, reduce->mutable_operand(0),
88         reduce->mutable_operand(1), pre_reduce_dims, reduce->to_apply());
89     pre_reduce->set_metadata(reduce->metadata());
90 
91     std::vector<int64> final_reduce_dims(reduce->dimensions().begin(),
92                                          reduce->dimensions().end());
93     final_reduce_dims.erase(
94         std::remove(final_reduce_dims.begin(), final_reduce_dims.end(),
95                     max_reduce_dim),
96         final_reduce_dims.end());
97     for (int64 i = 0; i < final_reduce_dims.size(); ++i) {
98       if (final_reduce_dims[i] > max_reduce_dim) {
99         final_reduce_dims[i]--;
100       }
101     }
102     std::unique_ptr<HloInstruction> final_reduce = HloInstruction::CreateReduce(
103         reduce->shape(),
104         reduce->parent()->AddInstruction(std::move(pre_reduce)),
105         reduce->mutable_operand(1), final_reduce_dims, reduce->to_apply());
106     return ReplaceWithNewInstruction(reduce, std::move(final_reduce));
107   }
108 };
109 
Run(HloModule * module)110 StatusOr<bool> ReductionSplitter::Run(HloModule *module) {
111   TF_ASSIGN_OR_RETURN(bool changed,
112                       ReductionSplitterVisitor().RunOnModule(module));
113   return changed;
114 }
115 
116 }  // namespace gpu
117 }  // namespace xla
118