• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
// RUN: mlir-opt %s -convert-scf-to-std -convert-std-to-llvm \
// RUN: | mlir-cpu-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext,%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \
// RUN: | FileCheck %s
5
// External printing helper (declaration only, no body in this file); it is
// presumably resolved from one of the runtime libraries passed through
// -shared-libs in the RUN lines. The llvm.emit_c_interface attribute asks the
// LLVM lowering to use the C-compatible calling interface for this symbol.
func private @print_memref_f32(memref<*xf32>) attributes {llvm.emit_c_interface}
7
// Entry point. Fills a 2x3 f32 buffer so that element [i, j] holds its
// row-major linear index i * 3 + j, prints it, runs every
// memref_reinterpret_cast test case on that same buffer, and then frees it.
func @main() -> () {
  %c0 = constant 0 : index
  %c1 = constant 1 : index

  // Initialize input.
  %input = alloc() : memref<2x3xf32>
  %dim_x = dim %input, %c0 : memref<2x3xf32>
  %dim_y = dim %input, %c1 : memref<2x3xf32>
  scf.parallel (%i, %j) = (%c0, %c0) to (%dim_x, %dim_y) step (%c1, %c1) {
    // Row-major linear index: i * dim_y + j, stored as f32.
    %prod = muli %i,  %dim_y : index
    %val = addi %prod, %j : index
    %val_i64 = index_cast %val : index to i64
    %val_f32 = sitofp %val_i64 : i64 to f32
    store %val_f32, %input[%i, %j] : memref<2x3xf32>
  }
  %unranked_input = memref_cast %input : memref<2x3xf32> to memref<*xf32>
  call @print_memref_f32(%unranked_input) : (memref<*xf32>) -> ()
  // CHECK: rank = 2 offset = 0 sizes = [2, 3] strides = [3, 1] data =
  // CHECK-NEXT: [0,   1,   2]
  // CHECK-NEXT: [3,   4,   5]

  // Test cases. Each one only builds views of %input that are used inside
  // its own body, so nothing refers to the buffer after these calls return.
  call @cast_ranked_memref_to_static_shape(%input) : (memref<2x3xf32>) -> ()
  call @cast_ranked_memref_to_dynamic_shape(%input) : (memref<2x3xf32>) -> ()
  call @cast_unranked_memref_to_static_shape(%input) : (memref<2x3xf32>) -> ()
  call @cast_unranked_memref_to_dynamic_shape(%input) : (memref<2x3xf32>) -> ()

  // Release the buffer allocated above; previously it was leaked.
  dealloc %input : memref<2x3xf32>
  return
}
36
// Reinterprets the ranked 2x3 input as a statically typed 6x1 view
// (offset 0, strides [1, 1]) over the same buffer and prints that view.
func @cast_ranked_memref_to_static_shape(%input : memref<2x3xf32>) {
  %view_6x1 = memref_reinterpret_cast %input to offset: [0], sizes: [6, 1], strides: [1, 1] : memref<2x3xf32> to memref<6x1xf32>
  %view_6x1_erased = memref_cast %view_6x1 : memref<6x1xf32> to memref<*xf32>
  call @print_memref_f32(%view_6x1_erased) : (memref<*xf32>) -> ()
  // CHECK: rank = 2 offset = 0 sizes = [6, 1] strides = [1, 1] data =
  // CHECK-NEXT: [0],
  // CHECK-NEXT: [1],
  // CHECK-NEXT: [2],
  // CHECK-NEXT: [3],
  // CHECK-NEXT: [4],
  // CHECK-NEXT: [5]
  return
}
54
// Reinterprets the ranked 2x3 input as a 1x6 view. Offset, sizes and strides
// are all supplied as SSA index values, so the result type is fully dynamic.
// Prints that view.
func @cast_ranked_memref_to_dynamic_shape(%input : memref<2x3xf32>) {
  %six = constant 6 : index
  %zero = constant 0 : index
  %one = constant 1 : index
  %view_1x6 = memref_reinterpret_cast %input to
      offset: [%zero], sizes: [%one, %six], strides: [%six, %one]
      : memref<2x3xf32> to memref<?x?xf32, offset: ?, strides: [?, ?]>
  %view_1x6_erased = memref_cast %view_1x6 : memref<?x?xf32, offset: ?, strides: [?, ?]> to memref<*xf32>
  call @print_memref_f32(%view_1x6_erased) : (memref<*xf32>) -> ()
  // CHECK: rank = 2 offset = 0 sizes = [1, 6] strides = [6, 1] data =
  // CHECK-NEXT: [0,   1,   2,   3,   4,   5]
  return
}
70
// Erases the rank of the 2x3 input first, then reinterprets the unranked
// memref as a statically typed 6x1 view (offset 0, strides [1, 1]) and
// prints that view.
func @cast_unranked_memref_to_static_shape(%input : memref<2x3xf32>) {
  %erased_in = memref_cast %input : memref<2x3xf32> to memref<*xf32>
  %view_6x1 = memref_reinterpret_cast %erased_in to offset: [0], sizes: [6, 1], strides: [1, 1] : memref<*xf32> to memref<6x1xf32>
  %view_6x1_erased = memref_cast %view_6x1 : memref<6x1xf32> to memref<*xf32>
  call @print_memref_f32(%view_6x1_erased) : (memref<*xf32>) -> ()
  // CHECK: rank = 2 offset = 0 sizes = [6, 1] strides = [1, 1] data =
  // CHECK-NEXT: [0],
  // CHECK-NEXT: [1],
  // CHECK-NEXT: [2],
  // CHECK-NEXT: [3],
  // CHECK-NEXT: [4],
  // CHECK-NEXT: [5]
  return
}
89
// Erases the rank of the 2x3 input, then reinterprets the unranked memref as
// a 1x6 view. Offset, sizes and strides are all SSA index values, so the
// result type is fully dynamic. Prints that view.
func @cast_unranked_memref_to_dynamic_shape(%input : memref<2x3xf32>) {
  %zero = constant 0 : index
  %one = constant 1 : index
  %six = constant 6 : index
  %erased_in = memref_cast %input : memref<2x3xf32> to memref<*xf32>
  %view_1x6 = memref_reinterpret_cast %erased_in to
      offset: [%zero], sizes: [%one, %six], strides: [%six, %one]
      : memref<*xf32> to memref<?x?xf32, offset: ?, strides: [?, ?]>
  %view_1x6_erased = memref_cast %view_1x6 : memref<?x?xf32, offset: ?, strides: [?, ?]> to memref<*xf32>
  call @print_memref_f32(%view_1x6_erased) : (memref<*xf32>) -> ()
  // CHECK: rank = 2 offset = 0 sizes = [1, 6] strides = [6, 1] data =
  // CHECK-NEXT: [0,   1,   2,   3,   4,   5]
  return
}
106