// RUN: mlir-hlo-opt -mhlo-legalize-einsum-to-dot-general %s -o - | FileCheck %s

// Einsum ",ii->i" with a scalar constant as the lhs operand. The expected
// lowering is a single mhlo.dot_general with no batching and no contracting
// dimensions, keeping the einsum's declared result type.
// NOTE(review): a dot_general of a scalar and a 6x6 operand with no batching
// or contracting dimensions would normally yield a 6x6 result; the expected
// tensor<6xf32> result type here mirrors the einsum's declared type. Confirm
// this is the intended pass behavior for the repeated-index (diagonal) case.
func @einsum_diag(%arg0: tensor<6x6xf32>) -> tensor<6xf32> {
  %0 = mhlo.constant dense<1.000000e+00> : tensor<f32>
  %1 = "mhlo.einsum"(%0, %arg0) {einsum_config = ",ii->i"} : (tensor<f32>, tensor<6x6xf32>) -> tensor<6xf32>
  return %1 : tensor<6xf32>
}
// CHECK-LABEL: func @einsum_diag
// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9_]+]]
// CHECK:         %[[CST:[a-zA-Z0-9_]+]] = mhlo.constant dense<{{.*}} : tensor<f32>
// CHECK:         %{{.+}} = "mhlo.dot_general"(%[[CST]], %[[ARG0]])
// CHECK-SAME:    dot_dimension_numbers = {
// CHECK-SAME:      lhs_batching_dimensions = dense<> : tensor<0xi64>,
// CHECK-SAME:      lhs_contracting_dimensions = dense<> : tensor<0xi64>,
// CHECK-SAME:      rhs_batching_dimensions = dense<> : tensor<0xi64>,
// CHECK-SAME:      rhs_contracting_dimensions = dense<> : tensor<0xi64>}
// CHECK-SAME:    : (tensor<f32>, tensor<6x6xf32>) -> tensor<6xf32>

// Batched matrix times higher-rank operand, "bxy,bijy->bijx". Index b
// (dim 0 of both operands) is expected as the batching dimension, and index y
// (lhs dim 2, rhs dim 3) as the contracting dimension.
func @einsum_batched_matrix_high_rank_vector_mul(%arg0: tensor<8x2x6xf32>, %arg1: tensor<8x5x3x6xf32>) -> tensor<8x5x3x2xf32> {
  %0 = "mhlo.einsum"(%arg0, %arg1) {einsum_config = "bxy,bijy->bijx"} : (tensor<8x2x6xf32>, tensor<8x5x3x6xf32>) -> tensor<8x5x3x2xf32>
  return %0 : tensor<8x5x3x2xf32>
}
// CHECK-LABEL: func @einsum_batched_matrix_high_rank_vector_mul
// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9_]+]]
// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9_]+]]
// CHECK:         %{{.+}} = "mhlo.dot_general"(%[[ARG0]], %[[ARG1]])
// CHECK-SAME:    dot_dimension_numbers = {
// CHECK-SAME:      lhs_batching_dimensions = dense<0> : tensor<1xi64>,
// CHECK-SAME:      lhs_contracting_dimensions = dense<2> : tensor<1xi64>,
// CHECK-SAME:      rhs_batching_dimensions = dense<0> : tensor<1xi64>,
// CHECK-SAME:      rhs_contracting_dimensions = dense<3> : tensor<1xi64>}
// CHECK-SAME:    : (tensor<8x2x6xf32>, tensor<8x5x3x6xf32>) -> tensor<8x5x3x2xf32>

// Plain matrix multiply "ij,jk->ik" on dynamically shaped operands. Index j is
// contracted (lhs dim 1, rhs dim 0); there are no batching dimensions.
func @matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %0 = "mhlo.einsum"(%arg0, %arg1) {einsum_config = "ij,jk->ik"} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}
// CHECK-LABEL: func @matmul
// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9_]+]]
// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9_]+]]
// CHECK:         %{{.+}} = "mhlo.dot_general"(%[[ARG0]], %[[ARG1]])
// CHECK-SAME:    dot_dimension_numbers = {
// CHECK-SAME:      lhs_batching_dimensions = dense<> : tensor<0xi64>,
// CHECK-SAME:      lhs_contracting_dimensions = dense<1> : tensor<1xi64>,
// CHECK-SAME:      rhs_batching_dimensions = dense<> : tensor<0xi64>,
// CHECK-SAME:      rhs_contracting_dimensions = dense<0> : tensor<1xi64>}
// CHECK-SAME:    : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>

// Matrix-vector product "ij,j->i" on dynamically shaped operands. Index j is
// contracted (lhs dim 1, rhs dim 0); there are no batching dimensions.
func @matvec(%arg0: tensor<?x?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
  %0 = "mhlo.einsum"(%arg0, %arg1) {einsum_config = "ij,j->i"} : (tensor<?x?xf32>, tensor<?xf32>) -> tensor<?xf32>
  return %0 : tensor<?xf32>
}
// CHECK-LABEL: func @matvec
// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9_]+]]
// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9_]+]]
// CHECK:         %{{.+}} = "mhlo.dot_general"(%[[ARG0]], %[[ARG1]])
// CHECK-SAME:    dot_dimension_numbers = {
// CHECK-SAME:      lhs_batching_dimensions = dense<> : tensor<0xi64>,
// CHECK-SAME:      lhs_contracting_dimensions = dense<1> : tensor<1xi64>,
// CHECK-SAME:      rhs_batching_dimensions = dense<> : tensor<0xi64>,
// CHECK-SAME:      rhs_contracting_dimensions = dense<0> : tensor<1xi64>}
// CHECK-SAME:    : (tensor<?x?xf32>, tensor<?xf32>) -> tensor<?xf32>

// Vector dot product "i,i->" producing a rank-0 result. Index i (dim 0 of
// both operands) is contracted; there are no batching dimensions.
func @dot(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<f32> {
  %0 = "mhlo.einsum"(%arg0, %arg1) {einsum_config = "i,i->"} : (tensor<?xf32>, tensor<?xf32>) -> tensor<f32>
  return %0 : tensor<f32>
}
// CHECK-LABEL: func @dot
// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9_]+]]
// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9_]+]]
// CHECK:         %{{.+}} = "mhlo.dot_general"(%[[ARG0]], %[[ARG1]])
// CHECK-SAME:    dot_dimension_numbers = {
// CHECK-SAME:      lhs_batching_dimensions = dense<> : tensor<0xi64>,
// CHECK-SAME:      lhs_contracting_dimensions = dense<0> : tensor<1xi64>,
// CHECK-SAME:      rhs_batching_dimensions = dense<> : tensor<0xi64>,
// CHECK-SAME:      rhs_contracting_dimensions = dense<0> : tensor<1xi64>}
// CHECK-SAME:    : (tensor<?xf32>, tensor<?xf32>) -> tensor<f32>