• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
// RUN: mlir-hlo-opt -mhlo-legalize-einsum-to-dot-general %s -o - | FileCheck %s

// Einsum with a rank-0 lhs and a repeated rhs index (",ii->i"). The checks
// expect the einsum to be replaced by a dot_general on the same operands with
// empty batching and contracting dimension lists and the einsum's result type.
// NOTE(review): "ii->i" selects the diagonal, which a dot_general with no
// contracting dimensions does not compute by itself; these checks appear to
// pin the pass's current output rather than a verified lowering - confirm.
func @einsum_diag(%arg0: tensor<6x6xf32>) -> tensor<6xf32> {
  %0 = mhlo.constant dense<1.000000e+00> : tensor<f32>
  %1 = "mhlo.einsum"(%0, %arg0) {einsum_config = ",ii->i"} : (tensor<f32>, tensor<6x6xf32>) -> tensor<6xf32>
  return %1 : tensor<6xf32>
}
// CHECK-LABEL: func @einsum_diag
// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9_]+]]
// CHECK:         %[[CST:.+]] = mhlo.constant dense<{{.*}} : tensor<f32>
// CHECK:         %{{.+}} = "mhlo.dot_general"(%[[CST]], %[[ARG0]])
// CHECK-SAME:      dot_dimension_numbers = {
// CHECK-SAME:        lhs_batching_dimensions = dense<> : tensor<0xi64>,
// CHECK-SAME:        lhs_contracting_dimensions = dense<> : tensor<0xi64>,
// CHECK-SAME:        rhs_batching_dimensions = dense<> : tensor<0xi64>,
// CHECK-SAME:        rhs_contracting_dimensions = dense<> : tensor<0xi64>}
// CHECK-SAME:    : (tensor<f32>, tensor<6x6xf32>) -> tensor<6xf32>

// Batched contraction "bxy,bijy->bijx". Index b appears in both operands and
// in the result, so it is a batching dimension (dim 0 on each side). Index y
// is missing from the result, so it is contracted (lhs dim 2, rhs dim 3).
// NOTE(review): dot_general lays out its result as batch dims, then lhs free
// dims, then rhs free dims (8x2x5x3 here). The checked result type is the
// einsum's 8x5x3x2 with no transpose checked - confirm this is intended.
func @einsum_batched_matrix_high_rank_vector_mul(%arg0: tensor<8x2x6xf32>, %arg1: tensor<8x5x3x6xf32>) -> tensor<8x5x3x2xf32> {
  %0 = "mhlo.einsum"(%arg0, %arg1) {einsum_config = "bxy,bijy->bijx"} : (tensor<8x2x6xf32>, tensor<8x5x3x6xf32>) -> tensor<8x5x3x2xf32>
  return %0 : tensor<8x5x3x2xf32>
}
// CHECK-LABEL: func @einsum_batched_matrix_high_rank_vector_mul
// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9_]+]]
// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9_]+]]
// CHECK:         %{{.+}} = "mhlo.dot_general"(%[[ARG0]], %[[ARG1]])
// CHECK-SAME:      dot_dimension_numbers = {
// CHECK-SAME:        lhs_batching_dimensions = dense<0> : tensor<1xi64>,
// CHECK-SAME:        lhs_contracting_dimensions = dense<2> : tensor<1xi64>,
// CHECK-SAME:        rhs_batching_dimensions = dense<0> : tensor<1xi64>,
// CHECK-SAME:        rhs_contracting_dimensions = dense<3> : tensor<1xi64>}
// CHECK-SAME:    : (tensor<8x2x6xf32>, tensor<8x5x3x6xf32>) -> tensor<8x5x3x2xf32>

// Dynamically shaped matrix-matrix product "ij,jk->ik". Index j is missing
// from the result, so it is contracted (lhs dim 1 with rhs dim 0). No index
// is shared by operands and result, so there are no batching dimensions.
func @matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %0 = "mhlo.einsum"(%arg0, %arg1) {einsum_config = "ij,jk->ik"} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}
// CHECK-LABEL: func @matmul
// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9_]+]]
// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9_]+]]
// CHECK:         %{{.+}} = "mhlo.dot_general"(%[[ARG0]], %[[ARG1]])
// CHECK-SAME:      dot_dimension_numbers = {
// CHECK-SAME:        lhs_batching_dimensions = dense<> : tensor<0xi64>,
// CHECK-SAME:        lhs_contracting_dimensions = dense<1> : tensor<1xi64>,
// CHECK-SAME:        rhs_batching_dimensions = dense<> : tensor<0xi64>,
// CHECK-SAME:        rhs_contracting_dimensions = dense<0> : tensor<1xi64>}
// CHECK-SAME:    : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>

// Dynamically shaped matrix-vector product "ij,j->i". Index j is contracted
// (lhs dim 1 with rhs dim 0). There are no batching dimensions.
func @matvec(%arg0: tensor<?x?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
  %0 = "mhlo.einsum"(%arg0, %arg1) {einsum_config = "ij,j->i"} : (tensor<?x?xf32>, tensor<?xf32>) -> tensor<?xf32>
  return %0 : tensor<?xf32>
}
// CHECK-LABEL: func @matvec
// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9_]+]]
// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9_]+]]
// CHECK:         %{{.+}} = "mhlo.dot_general"(%[[ARG0]], %[[ARG1]])
// CHECK-SAME:      dot_dimension_numbers = {
// CHECK-SAME:        lhs_batching_dimensions = dense<> : tensor<0xi64>,
// CHECK-SAME:        lhs_contracting_dimensions = dense<1> : tensor<1xi64>,
// CHECK-SAME:        rhs_batching_dimensions = dense<> : tensor<0xi64>,
// CHECK-SAME:        rhs_contracting_dimensions = dense<0> : tensor<1xi64>}
// CHECK-SAME:    : (tensor<?x?xf32>, tensor<?xf32>) -> tensor<?xf32>

// Dynamically shaped vector inner product "i,i->" producing a rank-0 result.
// Index i is contracted on both sides (dim 0). There are no batching
// dimensions.
func @dot(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<f32> {
  %0 = "mhlo.einsum"(%arg0, %arg1) {einsum_config = "i,i->"} : (tensor<?xf32>, tensor<?xf32>) -> tensor<f32>
  return %0 : tensor<f32>
}
// CHECK-LABEL: func @dot
// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9_]+]]
// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9_]+]]
// CHECK:         %{{.+}} = "mhlo.dot_general"(%[[ARG0]], %[[ARG1]])
// CHECK-SAME:      dot_dimension_numbers = {
// CHECK-SAME:        lhs_batching_dimensions = dense<> : tensor<0xi64>,
// CHECK-SAME:        lhs_contracting_dimensions = dense<0> : tensor<1xi64>,
// CHECK-SAME:        rhs_batching_dimensions = dense<> : tensor<0xi64>,
// CHECK-SAME:        rhs_contracting_dimensions = dense<0> : tensor<1xi64>}
// CHECK-SAME:    : (tensor<?xf32>, tensor<?xf32>) -> tensor<f32>