1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/dot_as_convolution_util.h"
17
18 #include "absl/types/optional.h"
19 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
20 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
21 #include "tensorflow/compiler/xla/service/shape_inference.h"
22 #include "tensorflow/compiler/xla/status_macros.h"
23
24 namespace xla {
25 namespace dot_as_convolution_util {
26
SpatialIsBatch(int64_t lhs_spatial_size,const WindowDimension & spatial_wd)27 SpatialBatchRepresentation SpatialIsBatch(int64_t lhs_spatial_size,
28 const WindowDimension& spatial_wd) {
29 if (lhs_spatial_size == spatial_wd.size() &&
30 lhs_spatial_size == spatial_wd.base_dilation() &&
31 ((std::max<int64_t>(1, lhs_spatial_size - 1) == spatial_wd.stride() &&
32 spatial_wd.window_dilation() == 1) ||
33 (std::max<int64_t>(1, lhs_spatial_size - 1) ==
34 spatial_wd.window_dilation() &&
35 spatial_wd.stride() == 1)) &&
36 spatial_wd.padding_high() == 0 && spatial_wd.padding_low() == 0 &&
37 !spatial_wd.window_reversal()) {
38 return SpatialBatchRepresentation::kUnpaddedVersion;
39 } else if (lhs_spatial_size == spatial_wd.size() &&
40 spatial_wd.padding_high() == lhs_spatial_size - 1 &&
41 spatial_wd.padding_low() == lhs_spatial_size - 1 &&
42 spatial_wd.window_reversal() &&
43 spatial_wd.window_dilation() == 1 &&
44 spatial_wd.stride() == lhs_spatial_size &&
45 spatial_wd.base_dilation() == lhs_spatial_size - 1) {
46 return SpatialBatchRepresentation::kPaddedVersion;
47 }
48 return SpatialBatchRepresentation::kNone;
49 }
50
SpatialIsLhsNonContracting(int64_t rhs_spatial_size,const WindowDimension & spatial_wd)51 bool SpatialIsLhsNonContracting(int64_t rhs_spatial_size,
52 const WindowDimension& spatial_wd) {
53 return spatial_wd.stride() == 1 && spatial_wd.window_dilation() == 1 &&
54 spatial_wd.base_dilation() == 1 && rhs_spatial_size == 1 &&
55 spatial_wd.size() == 1 && spatial_wd.padding_high() == 0 &&
56 spatial_wd.padding_low() == 0 && !spatial_wd.window_reversal();
57 }
58
SpatialIsRhsNonContracting(int64_t lhs_spatial_size,int64_t rhs_spatial_size,const WindowDimension & spatial_wd)59 bool SpatialIsRhsNonContracting(int64_t lhs_spatial_size,
60 int64_t rhs_spatial_size,
61 const WindowDimension& spatial_wd) {
62 return spatial_wd.stride() == 1 && spatial_wd.window_dilation() == 1 &&
63 spatial_wd.base_dilation() == 1 && lhs_spatial_size == 1 &&
64 spatial_wd.size() == rhs_spatial_size &&
65 spatial_wd.padding_high() == rhs_spatial_size - 1 &&
66 spatial_wd.padding_low() == rhs_spatial_size - 1 &&
67 spatial_wd.window_reversal();
68 }
69
SpatialIsContracting(int64_t lhs_spatial_size,int64_t rhs_spatial_size,const WindowDimension & spatial_wd)70 bool SpatialIsContracting(int64_t lhs_spatial_size, int64_t rhs_spatial_size,
71 const WindowDimension& spatial_wd) {
72 return lhs_spatial_size == spatial_wd.size() &&
73 spatial_wd.base_dilation() == 1 && spatial_wd.window_dilation() == 1 &&
74 spatial_wd.padding_high() == 0 && spatial_wd.padding_low() == 0 &&
75 !spatial_wd.window_reversal();
76 }
77
ParseConvolutionDimsInfo(const HloInstruction * conv)78 /* static */ DotConvolutionDimsInfo ParseConvolutionDimsInfo(
79 const HloInstruction* conv) {
80 CHECK_EQ(conv->opcode(), HloOpcode::kConvolution);
81 const auto& conv_dims = conv->convolution_dimension_numbers();
82 DotConvolutionDimsInfo dims;
83 dims.lhs_non_contracting_dims.push_back(
84 {conv_dims.input_batch_dimension(), -1,
85 conv_dims.output_batch_dimension(), -1});
86 dims.rhs_non_contracting_dims.push_back(
87 {-1, conv_dims.kernel_output_feature_dimension(),
88 conv_dims.output_feature_dimension(), -1});
89 dims.contracting_dims.push_back({conv_dims.input_feature_dimension(),
90 conv_dims.kernel_input_feature_dimension(),
91 -1, -1});
92
93 for (int64_t i = 0; i < conv_dims.input_spatial_dimensions_size(); ++i) {
94 int64_t lhs = conv_dims.input_spatial_dimensions(i);
95 int64_t lhs_size = conv->operand(0)->shape().dimensions(lhs);
96 int64_t rhs = conv_dims.kernel_spatial_dimensions(i);
97 int64_t rhs_size = conv->operand(1)->shape().dimensions(rhs);
98 int64_t output = conv_dims.output_spatial_dimensions(i);
99 const auto& wd = conv->window().dimensions(i);
100 if (SpatialIsBatch(lhs_size, wd) != SpatialBatchRepresentation::kNone) {
101 dims.batch_dims.push_back({lhs, rhs, output, i});
102 } else if (lhs_size == wd.size() && wd.base_dilation() == 1 &&
103 wd.window_dilation() == 1 && wd.padding_high() == 0 &&
104 wd.padding_low() == 0 && !wd.window_reversal()) {
105 // A contracting dimension be represented as a spatial dimension with
106 // window size C (contracting dimension size). Stride can be any size
107 // since there is only one window.
108 dims.contracting_dims.push_back({lhs, rhs, output, i});
109 } else if (wd.stride() == 1 && wd.window_dilation() == 1 &&
110 wd.base_dilation() == 1) {
111 if (rhs_size == 1 && wd.size() == 1 && wd.padding_high() == 0 &&
112 wd.padding_low() == 0 && !wd.window_reversal()) {
113 // A LHS non-contracting dimension can be represented as a spatial
114 // dimension with window size 1.
115 dims.lhs_non_contracting_dims.push_back({lhs, rhs, output, i});
116 } else if (lhs_size == 1 && wd.size() == rhs_size &&
117 wd.padding_high() == rhs_size - 1 &&
118 wd.padding_low() == rhs_size - 1 && wd.window_reversal()) {
119 // A RHS non-contracting dimension can be represented as a spatial
120 // dimension with window size N (non-contracting dimension size), low
121 // padding N - 1, high padding N - 1 and window reversal.
122 dims.rhs_non_contracting_dims.push_back({lhs, rhs, output, i});
123 } else {
124 dims.conv_spatial_dims.push_back({lhs, rhs, output, i});
125 }
126 } else {
127 dims.conv_spatial_dims.push_back({lhs, rhs, output, i});
128 }
129 }
130
131 return dims;
132 }
133
134 StatusOr<std::unique_ptr<HloInstruction>>
CreateShardedConvForDotGeneralConvolution(const HloInstruction & conv,const DotConvolutionDimsInfo & dot_dnums,HloInstruction * sharded_lhs_hlo,HloInstruction * sharded_rhs_hlo)135 CreateShardedConvForDotGeneralConvolution(
136 const HloInstruction& conv, const DotConvolutionDimsInfo& dot_dnums,
137 HloInstruction* sharded_lhs_hlo, HloInstruction* sharded_rhs_hlo) {
138 CHECK_EQ(conv.opcode(), HloOpcode::kConvolution);
139 const auto& conv_dnums = conv.convolution_dimension_numbers();
140 auto window = conv.window();
141 for (const auto& dim : dot_dnums.batch_dims) {
142 auto wd = window.mutable_dimensions(dim.spatial_dim);
143 wd->set_size(sharded_lhs_hlo->shape().dimensions(
144 conv_dnums.input_spatial_dimensions(dim.spatial_dim)));
145 wd->set_stride(std::max<int64>(1, wd->size() - 1));
146 wd->set_base_dilation(wd->size());
147 }
148 for (const auto& dim : dot_dnums.contracting_dims) {
149 if (dim.spatial_dim < 0) {
150 continue;
151 }
152 auto wd = window.mutable_dimensions(dim.spatial_dim);
153 wd->set_size(sharded_lhs_hlo->shape().dimensions(
154 conv_dnums.input_spatial_dimensions(dim.spatial_dim)));
155 }
156 for (const auto& dim : dot_dnums.rhs_non_contracting_dims) {
157 if (dim.spatial_dim < 0) {
158 continue;
159 }
160 auto wd = window.mutable_dimensions(dim.spatial_dim);
161 wd->set_size(sharded_rhs_hlo->shape().dimensions(
162 conv_dnums.kernel_spatial_dimensions(dim.spatial_dim)));
163 wd->set_padding_high(wd->size() - 1);
164 wd->set_padding_low(wd->size() - 1);
165 }
166 TF_ASSIGN_OR_RETURN(
167 Shape sharded_conv_shape,
168 ShapeInference::InferConvolveShape(
169 sharded_lhs_hlo->shape(), sharded_rhs_hlo->shape(),
170 /*feature_group_count=*/conv.feature_group_count(),
171 /*batch_group_count=*/conv.batch_group_count(), window, conv_dnums,
172 /*preferred_element_type=*/conv.shape().element_type()));
173 *sharded_conv_shape.mutable_layout() = conv.shape().layout();
174 return HloInstruction::CreateConvolve(
175 sharded_conv_shape, sharded_lhs_hlo, sharded_rhs_hlo,
176 /*feature_group_count=*/conv.feature_group_count(),
177 /*batch_group_count=*/conv.batch_group_count(), window, conv_dnums,
178 conv.precision_config());
179 }
180
ParseDotGeneralFromDot(const HloInstruction * dot)181 DotConvolutionDimsInfo ParseDotGeneralFromDot(const HloInstruction* dot) {
182 const auto& dot_dim_numbs = dot->dot_dimension_numbers();
183 dot_as_convolution_util::DotConvolutionDimsInfo dnums;
184 for (int64_t i = 0; i < dot_dim_numbs.lhs_batch_dimensions().size(); ++i) {
185 dnums.batch_dims.emplace_back();
186 dnums.batch_dims.back().lhs = dot_dim_numbs.lhs_batch_dimensions(i);
187 dnums.batch_dims.back().rhs = dot_dim_numbs.rhs_batch_dimensions(i);
188 dnums.batch_dims.back().output = i;
189 dnums.batch_dims.back().spatial_dim = -1;
190 }
191 for (int64_t i = 0; i < dot_dim_numbs.lhs_contracting_dimensions().size();
192 ++i) {
193 dnums.contracting_dims.emplace_back();
194 dnums.contracting_dims.back().lhs =
195 dot_dim_numbs.lhs_contracting_dimensions(i);
196 dnums.contracting_dims.back().rhs =
197 dot_dim_numbs.rhs_contracting_dimensions(i);
198 dnums.contracting_dims.back().output = -1;
199 dnums.contracting_dims.back().spatial_dim = -1;
200 }
201 for (int64_t i = 0; i < dot->operand(0)->shape().rank(); ++i) {
202 if (!absl::c_linear_search(dot_dim_numbs.lhs_batch_dimensions(), i) &&
203 !absl::c_linear_search(dot_dim_numbs.lhs_contracting_dimensions(), i)) {
204 dnums.lhs_non_contracting_dims.emplace_back();
205 dnums.lhs_non_contracting_dims.back().lhs = i;
206 dnums.lhs_non_contracting_dims.back().rhs = -1;
207 dnums.lhs_non_contracting_dims.back().output =
208 dot_dim_numbs.lhs_batch_dimensions_size() +
209 dnums.lhs_non_contracting_dims.size() - 1;
210 dnums.lhs_non_contracting_dims.back().spatial_dim = -1;
211 }
212 }
213 for (int64_t i = 0; i < dot->operand(1)->shape().rank(); ++i) {
214 if (!absl::c_linear_search(dot_dim_numbs.rhs_batch_dimensions(), i) &&
215 !absl::c_linear_search(dot_dim_numbs.rhs_contracting_dimensions(), i)) {
216 dnums.rhs_non_contracting_dims.emplace_back();
217 dnums.rhs_non_contracting_dims.back().lhs = -1;
218 dnums.rhs_non_contracting_dims.back().rhs = i;
219 dnums.rhs_non_contracting_dims.back().output =
220 dot_dim_numbs.lhs_batch_dimensions_size() +
221 dnums.lhs_non_contracting_dims.size() +
222 dnums.rhs_non_contracting_dims.size() - 1;
223 dnums.rhs_non_contracting_dims.back().spatial_dim = -1;
224 }
225 }
226 return dnums;
227 }
228
229 } // namespace dot_as_convolution_util
230 } // namespace xla
231