• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
// RUN: xla-opt %s -xla-legalize-xla-framework-to-llvm | FileCheck %s

// Scalar f32 constant that @buffer_type copies into its input buffer.
memref.global "private" constant @__constant_xf32 : memref<f32> = dense<42.0>

// Entry function with one !xla_framework.buffer argument annotated with
// xla_framework.input_mapping = 0. It views the buffer as a memref<f32> and
// copies the global constant into it.
func.func @buffer_type(%arg: !xla_framework.buffer {xla_framework.input_mapping = 0 : i64})
                      attributes {xla_entry} {
  %val = xla_framework.buffer_to_mem %arg : memref<f32>
  %global = memref.get_global @__constant_xf32 : memref<f32>
  memref.copy %global, %val : memref<f32> to memref<f32>
  func.return
}

// CHECK-LABEL: @buffer_type
// The following signature is always the same.
// CHECK-SAME: %{{[^:]*}}: !llvm.ptr<i8>
// CHECK-SAME: %{{[^:]*}}: !llvm.ptr<i8>
// CHECK-SAME: %{{[^:]*}}: !llvm.ptr<ptr<i8>>
// CHECK-SAME: %[[BUFFERS:.*]]: !llvm.ptr<ptr<i8>>
// CHECK-SAME: %{{[^:]*}}: !llvm.ptr<i64>
// CHECK-SAME: %{{[^:]*}}: !llvm.ptr<i64>) {
// Retrieve the pointer for input 0 from the buffer table as part of the
// function signature lowering.
// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : i64) : i32
// CHECK: %[[PTRS:.*]] = llvm.getelementptr %[[BUFFERS]][%[[C0]]] : (!llvm.ptr<ptr<i8>>, i32) -> !llvm.ptr<ptr<i8>>
// CHECK: %[[PTR0:.*]] = llvm.load %[[PTRS]] : !llvm.ptr<ptr<i8>>
// CHECK: %[[INP0:.*]] = llvm.bitcast %[[PTR0]] : !llvm.ptr<i8> to !llvm.ptr<f32>
// Build the memref descriptor as the buffer_to_mem lowering. Each stage gets
// its own capture so every insertvalue is checked against the previous one.
// CHECK: %[[MEMREF0:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
// CHECK: %[[MEMREF1:.*]] = llvm.insertvalue %[[INP0]], %[[MEMREF0]][0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
// CHECK: %[[MEMREF2:.*]] = llvm.insertvalue %[[INP0]], %[[MEMREF1]][1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
// CHECK: %[[C0_0:.*]] = llvm.mlir.constant(0 : index) : i64
// Use (rather than redefine) the captured constant and descriptor, so this
// verifies the zero constant is inserted into the descriptor built above.
// CHECK: llvm.insertvalue %[[C0_0]], %[[MEMREF2]][2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
// No return values in this case.
// CHECK: return


func.func @return_tuple(%result0: !xla_framework.buffer, %result1: !xla_framework.buffer)
                      attributes {xla_entry, xla_framework.result_inner_mapping=[1,2], xla_framework.result_mapping=0} {
  func.return
}


// CHECK-LABEL: @return_tuple
// The following signature is always the same.
// CHECK-SAME: %{{[^:]*}}: !llvm.ptr<i8>
// CHECK-SAME: %{{[^:]*}}: !llvm.ptr<i8>
// CHECK-SAME: %{{[^:]*}}: !llvm.ptr<ptr<i8>>
// CHECK-SAME: %[[BUFFERS:.*]]: !llvm.ptr<ptr<i8>>
// CHECK-SAME: %{{[^:]*}}: !llvm.ptr<i64>
// CHECK-SAME: %{{[^:]*}}: !llvm.ptr<i64>) {
// The sequence below loads buffer-table entries 1 and 2 (the values in
// result_inner_mapping). It stores them into tuple slots 0 and 1 of the buffer
// at table entry 0 (result_mapping). Every capture has a unique name so no
// FileCheck variable is silently redefined.
// Get tuple.
// CHECK-NEXT: %[[TUP_IDX_A:.*]] = llvm.mlir.constant(0 : i64) : i32
// CHECK-NEXT: %[[TUP_SLOT_A:.*]] = llvm.getelementptr %[[BUFFERS]][%[[TUP_IDX_A]]] : (!llvm.ptr<ptr<i8>>, i32) -> !llvm.ptr<ptr<i8>>
// CHECK-NEXT: %[[TUP_RAW_A:.*]] = llvm.load %[[TUP_SLOT_A]] : !llvm.ptr<ptr<i8>>
// Get individual output buffer.
// CHECK-NEXT: %[[BUF_IDX1:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK-NEXT: %[[BUF_SLOT1:.*]] = llvm.getelementptr %[[BUFFERS]][%[[BUF_IDX1]]] : (!llvm.ptr<ptr<i8>>, i32) -> !llvm.ptr<ptr<i8>>
// CHECK-NEXT: %[[BUF_PTR1:.*]] = llvm.load %[[BUF_SLOT1]] : !llvm.ptr<ptr<i8>>
// Store into tuple slot 0.
// CHECK-NEXT: %[[TUP_A:.*]] = llvm.bitcast %[[TUP_RAW_A]] : !llvm.ptr<i8> to !llvm.ptr<ptr<i8>>
// CHECK-NEXT: %[[ELEM_IDX0:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK-NEXT: %[[ELEM0:.*]] = llvm.getelementptr %[[TUP_A]][%[[ELEM_IDX0]]] : (!llvm.ptr<ptr<i8>>, i32) -> !llvm.ptr<ptr<i8>>
// CHECK-NEXT: llvm.store %[[BUF_PTR1]], %[[ELEM0]] : !llvm.ptr<ptr<i8>>
// Get tuple.
// CHECK-NEXT: %[[TUP_IDX_B:.*]] = llvm.mlir.constant(0 : i64) : i32
// CHECK-NEXT: %[[TUP_SLOT_B:.*]] = llvm.getelementptr %[[BUFFERS]][%[[TUP_IDX_B]]] : (!llvm.ptr<ptr<i8>>, i32) -> !llvm.ptr<ptr<i8>>
// CHECK-NEXT: %[[TUP_RAW_B:.*]] = llvm.load %[[TUP_SLOT_B]] : !llvm.ptr<ptr<i8>>
// Get individual output buffer.
// CHECK-NEXT: %[[BUF_IDX2:.*]] = llvm.mlir.constant(2 : i32) : i32
// CHECK-NEXT: %[[BUF_SLOT2:.*]] = llvm.getelementptr %[[BUFFERS]][%[[BUF_IDX2]]] : (!llvm.ptr<ptr<i8>>, i32) -> !llvm.ptr<ptr<i8>>
// CHECK-NEXT: %[[BUF_PTR2:.*]] = llvm.load %[[BUF_SLOT2]] : !llvm.ptr<ptr<i8>>
// Store into tuple slot 1.
// CHECK-NEXT: %[[TUP_B:.*]] = llvm.bitcast %[[TUP_RAW_B]] : !llvm.ptr<i8> to !llvm.ptr<ptr<i8>>
// CHECK-NEXT: %[[ELEM_IDX1:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK-NEXT: %[[ELEM1:.*]] = llvm.getelementptr %[[TUP_B]][%[[ELEM_IDX1]]] : (!llvm.ptr<ptr<i8>>, i32) -> !llvm.ptr<ptr<i8>>
// CHECK-NEXT: llvm.store %[[BUF_PTR2]], %[[ELEM1]] : !llvm.ptr<ptr<i8>>
// No return values.
// CHECK-NEXT:  return
