• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/cpu/shape_partition.h"
17 
18 namespace xla {
19 namespace cpu {
20 
Run(int64 target_partition_count)21 std::vector<int64> ShapePartitionAssigner::Run(int64 target_partition_count) {
22   // Gather outer-most dims where dim_size >= 'target_partition_count'.
23   // This may include the inner-dim as LLVM can vectorize loops with dynamic
24   // bounds.
25   std::vector<int64> outer_dims;
26   int64 outer_dim_size = 1;
27   // TODO(b/27458679) Consider reserving enough minor dimensions (based on
28   // target vector register width) to enable vector instructions.
29   for (int i = shape_.layout().minor_to_major_size() - 1; i >= 0; --i) {
30     const int64 dimension = shape_.layout().minor_to_major(i);
31     outer_dims.push_back(dimension);
32     outer_dim_size *= shape_.dimensions(dimension);
33     if (outer_dim_size >= target_partition_count) {
34       break;
35     }
36   }
37 
38   // Clip target partition count if outer dim size is insufficient to cover.
39   target_partition_count = std::min(outer_dim_size, target_partition_count);
40 
41   // Calculate the target number of partitions per-dimension, by factoring
42   // 'target_partition_count' into 'num_outer_dims' equal terms.
43   // EX:
44   // *) target_partition_count = 16
45   // *) out_dim_count = 2
46   // *) target_dim_partition_count = 16 ^ (1.0 / 2) == 4
47   const int64 target_dim_partition_count = std::pow(
48       static_cast<double>(target_partition_count), 1.0 / outer_dims.size());
49 
50   // Assign feasible dimension partitions based on 'target_dim_partition_count'
51   // and actual dimension sizes from 'shape_'.
52   std::vector<int64> dimension_partition_counts(outer_dims.size());
53   for (int64 i = 0; i < outer_dims.size(); ++i) {
54     dimension_partition_counts[i] =
55         std::min(static_cast<int64>(shape_.dimensions(outer_dims[i])),
56                  target_dim_partition_count);
57   }
58 
59   // Check if total partition count is below 'target_partition_count'.
60   // This can occur if some dimensions in 'shape_' are below the
61   // 'target_dim_partition_count' threshold.
62   if (GetTotalPartitionCount(dimension_partition_counts) <
63       target_partition_count) {
64     // Assign additional partitions (greedily to outer dimensions), if doing
65     // so would keep the total number of partitions <= 'target_partition_count',
66     // using one pass over 'dimension_partition_counts'.
67     for (int64 i = 0; i < dimension_partition_counts.size(); ++i) {
68       const int64 current_dim_partition_count = dimension_partition_counts[i];
69       const int64 other_dims_partition_count =
70           GetTotalPartitionCount(dimension_partition_counts) /
71           current_dim_partition_count;
72       // Constraint: (current + additional) * other <= target
73       // Calculate: additional = target / other - current
74       int64 additional_partition_count =
75           target_partition_count / other_dims_partition_count -
76           current_dim_partition_count;
77       // Clip 'additional_partition_count' by current dimension size.
78       additional_partition_count = std::min(
79           shape_.dimensions(outer_dims[i]) - dimension_partition_counts[i],
80           additional_partition_count);
81       if (additional_partition_count > 0) {
82         dimension_partition_counts[i] += additional_partition_count;
83       }
84     }
85   }
86 
87   return dimension_partition_counts;
88 }
89 
GetTotalPartitionCount(const std::vector<int64> & dimension_partition_counts)90 int64 ShapePartitionAssigner::GetTotalPartitionCount(
91     const std::vector<int64>& dimension_partition_counts) {
92   int64 total_partition_count = 1;
93   for (int64 dim_partition_count : dimension_partition_counts) {
94     total_partition_count *= dim_partition_count;
95   }
96   return total_partition_count;
97 }
98 
ShapePartitionIterator(const Shape & shape,const std::vector<int64> & dimension_partition_counts)99 ShapePartitionIterator::ShapePartitionIterator(
100     const Shape& shape, const std::vector<int64>& dimension_partition_counts)
101     : shape_(shape),
102       dimension_partition_counts_(dimension_partition_counts),
103       dimensions_(dimension_partition_counts_.size()),
104       dimension_partition_sizes_(dimension_partition_counts_.size()),
105       dimension_partition_strides_(dimension_partition_counts_.size()) {
106   // Store partitioned outer dimensions from 'shape_'.
107   for (int i = 0; i < dimensions_.size(); ++i) {
108     dimensions_[i] = shape_.layout().minor_to_major(
109         shape_.layout().minor_to_major_size() - 1 - i);
110   }
111 
112   // Calculate partition size for each dimension (note that the size of
113   // the last partition in each dimension may be different if the dimension
114   // size is not a multiple of partition size).
115   for (int i = 0; i < dimension_partition_sizes_.size(); ++i) {
116     const int64 dim_size = shape_.dimensions(dimensions_[i]);
117     dimension_partition_sizes_[i] =
118         std::max(int64{1}, dim_size / dimension_partition_counts_[i]);
119   }
120 
121   // Calculate the partition strides for each dimension.
122   dimension_partition_strides_[dimension_partition_strides_.size() - 1] = 1;
123   for (int i = dimension_partition_strides_.size() - 2; i >= 0; --i) {
124     dimension_partition_strides_[i] = dimension_partition_strides_[i + 1] *
125                                       dimension_partition_counts_[i + 1];
126   }
127 }
128 
GetPartition(int64 index) const129 std::vector<std::pair<int64, int64>> ShapePartitionIterator::GetPartition(
130     int64 index) const {
131   // Calculate and return the partition for 'index'.
132   // Returns for each dimension: (partition_start, partition_size).
133   std::vector<std::pair<int64, int64>> partition(dimensions_.size());
134   for (int64 i = 0; i < partition.size(); ++i) {
135     // Calculate the index for dimension 'i'.
136     const int64 partition_index = index / dimension_partition_strides_[i];
137     // Calculate dimension partition start at 'partition_index'.
138     partition[i].first = partition_index * dimension_partition_sizes_[i];
139     // Calculate dimension partition size (note that the last partition size
140     // may be adjusted if dimension size is not a multiple of partition size).
141     if (partition_index == dimension_partition_counts_[i] - 1) {
142       // Last partition in this dimension.
143       partition[i].second =
144           shape_.dimensions(dimensions_[i]) - partition[i].first;
145     } else {
146       partition[i].second = dimension_partition_sizes_[i];
147     }
148     CHECK_GT(partition[i].second, 0);
149     // Update index to remove contribution from current dimension.
150     index -= partition_index * dimension_partition_strides_[i];
151   }
152   return partition;
153 }
154 
GetTotalPartitionCount() const155 int64 ShapePartitionIterator::GetTotalPartitionCount() const {
156   return ShapePartitionAssigner::GetTotalPartitionCount(
157       dimension_partition_counts_);
158 }
159 
160 }  // namespace cpu
161 }  // namespace xla
162