1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/dot_as_convolution_util.h"
17 
18 #include "absl/types/optional.h"
19 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
20 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
21 #include "tensorflow/compiler/xla/service/shape_inference.h"
22 #include "tensorflow/compiler/xla/status_macros.h"
23 
24 namespace xla {
25 namespace dot_as_convolution_util {
26 
SpatialIsBatch(int64_t lhs_spatial_size,const WindowDimension & spatial_wd)27 SpatialBatchRepresentation SpatialIsBatch(int64_t lhs_spatial_size,
28                                           const WindowDimension& spatial_wd) {
29   if (lhs_spatial_size == spatial_wd.size() &&
30       lhs_spatial_size == spatial_wd.base_dilation() &&
31       ((std::max<int64_t>(1, lhs_spatial_size - 1) == spatial_wd.stride() &&
32         spatial_wd.window_dilation() == 1) ||
33        (std::max<int64_t>(1, lhs_spatial_size - 1) ==
34             spatial_wd.window_dilation() &&
35         spatial_wd.stride() == 1)) &&
36       spatial_wd.padding_high() == 0 && spatial_wd.padding_low() == 0 &&
37       !spatial_wd.window_reversal()) {
38     return SpatialBatchRepresentation::kUnpaddedVersion;
39   } else if (lhs_spatial_size == spatial_wd.size() &&
40              spatial_wd.padding_high() == lhs_spatial_size - 1 &&
41              spatial_wd.padding_low() == lhs_spatial_size - 1 &&
42              spatial_wd.window_reversal() &&
43              spatial_wd.window_dilation() == 1 &&
44              spatial_wd.stride() == lhs_spatial_size &&
45              spatial_wd.base_dilation() == lhs_spatial_size - 1) {
46     return SpatialBatchRepresentation::kPaddedVersion;
47   }
48   return SpatialBatchRepresentation::kNone;
49 }
50 
SpatialIsLhsNonContracting(int64_t rhs_spatial_size,const WindowDimension & spatial_wd)51 bool SpatialIsLhsNonContracting(int64_t rhs_spatial_size,
52                                 const WindowDimension& spatial_wd) {
53   return spatial_wd.stride() == 1 && spatial_wd.window_dilation() == 1 &&
54          spatial_wd.base_dilation() == 1 && rhs_spatial_size == 1 &&
55          spatial_wd.size() == 1 && spatial_wd.padding_high() == 0 &&
56          spatial_wd.padding_low() == 0 && !spatial_wd.window_reversal();
57 }
58 
SpatialIsRhsNonContracting(int64_t lhs_spatial_size,int64_t rhs_spatial_size,const WindowDimension & spatial_wd)59 bool SpatialIsRhsNonContracting(int64_t lhs_spatial_size,
60                                 int64_t rhs_spatial_size,
61                                 const WindowDimension& spatial_wd) {
62   return spatial_wd.stride() == 1 && spatial_wd.window_dilation() == 1 &&
63          spatial_wd.base_dilation() == 1 && lhs_spatial_size == 1 &&
64          spatial_wd.size() == rhs_spatial_size &&
65          spatial_wd.padding_high() == rhs_spatial_size - 1 &&
66          spatial_wd.padding_low() == rhs_spatial_size - 1 &&
67          spatial_wd.window_reversal();
68 }
69 
SpatialIsContracting(int64_t lhs_spatial_size,int64_t rhs_spatial_size,const WindowDimension & spatial_wd)70 bool SpatialIsContracting(int64_t lhs_spatial_size, int64_t rhs_spatial_size,
71                           const WindowDimension& spatial_wd) {
72   return lhs_spatial_size == spatial_wd.size() &&
73          spatial_wd.base_dilation() == 1 && spatial_wd.window_dilation() == 1 &&
74          spatial_wd.padding_high() == 0 && spatial_wd.padding_low() == 0 &&
75          !spatial_wd.window_reversal();
76 }
77 
ParseConvolutionDimsInfo(const HloInstruction * conv)78 /* static */ DotConvolutionDimsInfo ParseConvolutionDimsInfo(
79     const HloInstruction* conv) {
80   CHECK_EQ(conv->opcode(), HloOpcode::kConvolution);
81   const auto& conv_dims = conv->convolution_dimension_numbers();
82   DotConvolutionDimsInfo dims;
83   dims.lhs_non_contracting_dims.push_back(
84       {conv_dims.input_batch_dimension(), -1,
85        conv_dims.output_batch_dimension(), -1});
86   dims.rhs_non_contracting_dims.push_back(
87       {-1, conv_dims.kernel_output_feature_dimension(),
88        conv_dims.output_feature_dimension(), -1});
89   dims.contracting_dims.push_back({conv_dims.input_feature_dimension(),
90                                    conv_dims.kernel_input_feature_dimension(),
91                                    -1, -1});
92 
93   for (int64_t i = 0; i < conv_dims.input_spatial_dimensions_size(); ++i) {
94     int64_t lhs = conv_dims.input_spatial_dimensions(i);
95     int64_t lhs_size = conv->operand(0)->shape().dimensions(lhs);
96     int64_t rhs = conv_dims.kernel_spatial_dimensions(i);
97     int64_t rhs_size = conv->operand(1)->shape().dimensions(rhs);
98     int64_t output = conv_dims.output_spatial_dimensions(i);
99     const auto& wd = conv->window().dimensions(i);
100     if (SpatialIsBatch(lhs_size, wd) != SpatialBatchRepresentation::kNone) {
101       dims.batch_dims.push_back({lhs, rhs, output, i});
102     } else if (lhs_size == wd.size() && wd.base_dilation() == 1 &&
103                wd.window_dilation() == 1 && wd.padding_high() == 0 &&
104                wd.padding_low() == 0 && !wd.window_reversal()) {
105       // A contracting dimension be represented as a spatial dimension with
106       // window size C (contracting dimension size). Stride can be any size
107       // since there is only one window.
108       dims.contracting_dims.push_back({lhs, rhs, output, i});
109     } else if (wd.stride() == 1 && wd.window_dilation() == 1 &&
110                wd.base_dilation() == 1) {
111       if (rhs_size == 1 && wd.size() == 1 && wd.padding_high() == 0 &&
112           wd.padding_low() == 0 && !wd.window_reversal()) {
113         // A LHS non-contracting dimension can be represented as a spatial
114         // dimension with window size 1.
115         dims.lhs_non_contracting_dims.push_back({lhs, rhs, output, i});
116       } else if (lhs_size == 1 && wd.size() == rhs_size &&
117                  wd.padding_high() == rhs_size - 1 &&
118                  wd.padding_low() == rhs_size - 1 && wd.window_reversal()) {
119         // A RHS non-contracting dimension can be represented as a spatial
120         // dimension with window size N (non-contracting dimension size), low
121         // padding N - 1,  high padding N - 1 and window reversal.
122         dims.rhs_non_contracting_dims.push_back({lhs, rhs, output, i});
123       } else {
124         dims.conv_spatial_dims.push_back({lhs, rhs, output, i});
125       }
126     } else {
127       dims.conv_spatial_dims.push_back({lhs, rhs, output, i});
128     }
129   }
130 
131   return dims;
132 }
133 
134 StatusOr<std::unique_ptr<HloInstruction>>
CreateShardedConvForDotGeneralConvolution(const HloInstruction & conv,const DotConvolutionDimsInfo & dot_dnums,HloInstruction * sharded_lhs_hlo,HloInstruction * sharded_rhs_hlo)135 CreateShardedConvForDotGeneralConvolution(
136     const HloInstruction& conv, const DotConvolutionDimsInfo& dot_dnums,
137     HloInstruction* sharded_lhs_hlo, HloInstruction* sharded_rhs_hlo) {
138   CHECK_EQ(conv.opcode(), HloOpcode::kConvolution);
139   const auto& conv_dnums = conv.convolution_dimension_numbers();
140   auto window = conv.window();
141   for (const auto& dim : dot_dnums.batch_dims) {
142     auto wd = window.mutable_dimensions(dim.spatial_dim);
143     wd->set_size(sharded_lhs_hlo->shape().dimensions(
144         conv_dnums.input_spatial_dimensions(dim.spatial_dim)));
145     wd->set_stride(std::max<int64>(1, wd->size() - 1));
146     wd->set_base_dilation(wd->size());
147   }
148   for (const auto& dim : dot_dnums.contracting_dims) {
149     if (dim.spatial_dim < 0) {
150       continue;
151     }
152     auto wd = window.mutable_dimensions(dim.spatial_dim);
153     wd->set_size(sharded_lhs_hlo->shape().dimensions(
154         conv_dnums.input_spatial_dimensions(dim.spatial_dim)));
155   }
156   for (const auto& dim : dot_dnums.rhs_non_contracting_dims) {
157     if (dim.spatial_dim < 0) {
158       continue;
159     }
160     auto wd = window.mutable_dimensions(dim.spatial_dim);
161     wd->set_size(sharded_rhs_hlo->shape().dimensions(
162         conv_dnums.kernel_spatial_dimensions(dim.spatial_dim)));
163     wd->set_padding_high(wd->size() - 1);
164     wd->set_padding_low(wd->size() - 1);
165   }
166   TF_ASSIGN_OR_RETURN(
167       Shape sharded_conv_shape,
168       ShapeInference::InferConvolveShape(
169           sharded_lhs_hlo->shape(), sharded_rhs_hlo->shape(),
170           /*feature_group_count=*/conv.feature_group_count(),
171           /*batch_group_count=*/conv.batch_group_count(), window, conv_dnums,
172           /*preferred_element_type=*/conv.shape().element_type()));
173   *sharded_conv_shape.mutable_layout() = conv.shape().layout();
174   return HloInstruction::CreateConvolve(
175       sharded_conv_shape, sharded_lhs_hlo, sharded_rhs_hlo,
176       /*feature_group_count=*/conv.feature_group_count(),
177       /*batch_group_count=*/conv.batch_group_count(), window, conv_dnums,
178       conv.precision_config());
179 }
180 
ParseDotGeneralFromDot(const HloInstruction * dot)181 DotConvolutionDimsInfo ParseDotGeneralFromDot(const HloInstruction* dot) {
182   const auto& dot_dim_numbs = dot->dot_dimension_numbers();
183   dot_as_convolution_util::DotConvolutionDimsInfo dnums;
184   for (int64_t i = 0; i < dot_dim_numbs.lhs_batch_dimensions().size(); ++i) {
185     dnums.batch_dims.emplace_back();
186     dnums.batch_dims.back().lhs = dot_dim_numbs.lhs_batch_dimensions(i);
187     dnums.batch_dims.back().rhs = dot_dim_numbs.rhs_batch_dimensions(i);
188     dnums.batch_dims.back().output = i;
189     dnums.batch_dims.back().spatial_dim = -1;
190   }
191   for (int64_t i = 0; i < dot_dim_numbs.lhs_contracting_dimensions().size();
192        ++i) {
193     dnums.contracting_dims.emplace_back();
194     dnums.contracting_dims.back().lhs =
195         dot_dim_numbs.lhs_contracting_dimensions(i);
196     dnums.contracting_dims.back().rhs =
197         dot_dim_numbs.rhs_contracting_dimensions(i);
198     dnums.contracting_dims.back().output = -1;
199     dnums.contracting_dims.back().spatial_dim = -1;
200   }
201   for (int64_t i = 0; i < dot->operand(0)->shape().rank(); ++i) {
202     if (!absl::c_linear_search(dot_dim_numbs.lhs_batch_dimensions(), i) &&
203         !absl::c_linear_search(dot_dim_numbs.lhs_contracting_dimensions(), i)) {
204       dnums.lhs_non_contracting_dims.emplace_back();
205       dnums.lhs_non_contracting_dims.back().lhs = i;
206       dnums.lhs_non_contracting_dims.back().rhs = -1;
207       dnums.lhs_non_contracting_dims.back().output =
208           dot_dim_numbs.lhs_batch_dimensions_size() +
209           dnums.lhs_non_contracting_dims.size() - 1;
210       dnums.lhs_non_contracting_dims.back().spatial_dim = -1;
211     }
212   }
213   for (int64_t i = 0; i < dot->operand(1)->shape().rank(); ++i) {
214     if (!absl::c_linear_search(dot_dim_numbs.rhs_batch_dimensions(), i) &&
215         !absl::c_linear_search(dot_dim_numbs.rhs_contracting_dimensions(), i)) {
216       dnums.rhs_non_contracting_dims.emplace_back();
217       dnums.rhs_non_contracting_dims.back().lhs = -1;
218       dnums.rhs_non_contracting_dims.back().rhs = i;
219       dnums.rhs_non_contracting_dims.back().output =
220           dot_dim_numbs.lhs_batch_dimensions_size() +
221           dnums.lhs_non_contracting_dims.size() +
222           dnums.rhs_non_contracting_dims.size() - 1;
223       dnums.rhs_non_contracting_dims.back().spatial_dim = -1;
224     }
225   }
226   return dnums;
227 }
228 
229 }  // namespace dot_as_convolution_util
230 }  // namespace xla
231