// RUN: xla-opt %s -xla-legalize-xla-framework-to-llvm | FileCheck %s

memref.global "private" constant @__constant_xf32 : memref<f32> = dense<42.0>

// An entry function that reads input buffer 0 as a memref<f32> and copies a
// global constant into it.
func.func @buffer_type(%arg: !xla_framework.buffer {xla_framework.input_mapping = 0 : i64})
    attributes {xla_entry} {
  %val = xla_framework.buffer_to_mem %arg : memref<f32>
  %global = memref.get_global @__constant_xf32 : memref<f32>
  memref.copy %global, %val : memref<f32> to memref<f32>
  func.return
}

// CHECK-LABEL: @buffer_type
// The following signature is always the same.
// CHECK-SAME: %{{[^:]*}}: !llvm.ptr<i8>
// CHECK-SAME: %{{[^:]*}}: !llvm.ptr<i8>
// CHECK-SAME: %{{[^:]*}}: !llvm.ptr<ptr<i8>>
// CHECK-SAME: %[[BUFFERS:.*]]: !llvm.ptr<ptr<i8>>
// CHECK-SAME: %{{[^:]*}}: !llvm.ptr<i64>
// CHECK-SAME: %{{[^:]*}}: !llvm.ptr<i64>) {
// Retrieve the pointer to input buffer 0 as part of the function signature
// lowering.
// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : i64) : i32
// CHECK: %[[PTRS:.*]] = llvm.getelementptr %[[BUFFERS]][%[[C0]]] : (!llvm.ptr<ptr<i8>>, i32) -> !llvm.ptr<ptr<i8>>
// CHECK: %[[PTR0:.*]] = llvm.load %[[PTRS]] : !llvm.ptr<ptr<i8>>
// CHECK: %[[INP0:.*]] = llvm.bitcast %[[PTR0]] : !llvm.ptr<i8> to !llvm.ptr<f32>
// Create the memref descriptor as part of the buffer_to_mem lowering: the
// allocated and aligned pointers are both the input pointer, and the offset
// is 0.
// CHECK: %[[DESC0:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
// CHECK: %[[DESC1:.*]] = llvm.insertvalue %[[INP0]], %[[DESC0]][0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
// CHECK: %[[DESC2:.*]] = llvm.insertvalue %[[INP0]], %[[DESC1]][1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
// CHECK: %[[C0_0:.*]] = llvm.mlir.constant(0 : index) : i64
// Use (do not redefine) the captured constant and descriptor so the offset
// insertion is tied to the values produced above.
// CHECK: llvm.insertvalue %[[C0_0]], %[[DESC2]][2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
// No return values in this case.
// CHECK: return


// An entry function whose result is a tuple: result_mapping=0 selects buffer 0
// as the tuple, and result_inner_mapping=[1,2] lists the buffers whose
// pointers are stored into tuple elements 0 and 1.
func.func @return_tuple(%result0: !xla_framework.buffer, %result1: !xla_framework.buffer)
    attributes {xla_entry, xla_framework.result_inner_mapping=[1,2], xla_framework.result_mapping=0} {
  func.return
}


// CHECK-LABEL: @return_tuple
// The following signature is always the same.
// CHECK-SAME: %{{[^:]*}}: !llvm.ptr<i8>
// CHECK-SAME: %{{[^:]*}}: !llvm.ptr<i8>
// CHECK-SAME: %{{[^:]*}}: !llvm.ptr<ptr<i8>>
// CHECK-SAME: %[[BUFFERS:.*]]: !llvm.ptr<ptr<i8>>
// CHECK-SAME: %{{[^:]*}}: !llvm.ptr<i64>
// CHECK-SAME: %{{[^:]*}}: !llvm.ptr<i64>) {
// Get the tuple buffer (buffer 0).
// CHECK-NEXT: %[[C0:.*]] = llvm.mlir.constant(0 : i64) : i32
// CHECK-NEXT: %[[PTRS0:.*]] = llvm.getelementptr %[[BUFFERS]][%[[C0]]] : (!llvm.ptr<ptr<i8>>, i32) -> !llvm.ptr<ptr<i8>>
// CHECK-NEXT: %[[PTR0:.*]] = llvm.load %[[PTRS0]] : !llvm.ptr<ptr<i8>>
// Get the first output buffer (buffer 1).
// CHECK-NEXT: %[[C1:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK-NEXT: %[[PTRS1:.*]] = llvm.getelementptr %[[BUFFERS]][%[[C1]]] : (!llvm.ptr<ptr<i8>>, i32) -> !llvm.ptr<ptr<i8>>
// CHECK-NEXT: %[[PTR1:.*]] = llvm.load %[[PTRS1]] : !llvm.ptr<ptr<i8>>
// Store it into tuple element 0.
// CHECK-NEXT: %[[TUPLE:.*]] = llvm.bitcast %[[PTR0]] : !llvm.ptr<i8> to !llvm.ptr<ptr<i8>>
// CHECK-NEXT: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK-NEXT: %[[TUPLE_ELEMENT:.*]] = llvm.getelementptr %[[TUPLE]][%[[C0]]] : (!llvm.ptr<ptr<i8>>, i32) -> !llvm.ptr<ptr<i8>>
// CHECK-NEXT: llvm.store %[[PTR1]], %[[TUPLE_ELEMENT]] : !llvm.ptr<ptr<i8>>
// Get the tuple buffer (buffer 0) again.
// CHECK-NEXT: %[[C0:.*]] = llvm.mlir.constant(0 : i64) : i32
// CHECK-NEXT: %[[PTRS0:.*]] = llvm.getelementptr %[[BUFFERS]][%[[C0]]] : (!llvm.ptr<ptr<i8>>, i32) -> !llvm.ptr<ptr<i8>>
// CHECK-NEXT: %[[PTR0:.*]] = llvm.load %[[PTRS0]] : !llvm.ptr<ptr<i8>>
// Get the second output buffer (buffer 2).
// CHECK-NEXT: %[[C2:.*]] = llvm.mlir.constant(2 : i32) : i32
// CHECK-NEXT: %[[PTRS2:.*]] = llvm.getelementptr %[[BUFFERS]][%[[C2]]] : (!llvm.ptr<ptr<i8>>, i32) -> !llvm.ptr<ptr<i8>>
// CHECK-NEXT: %[[PTR2:.*]] = llvm.load %[[PTRS2]] : !llvm.ptr<ptr<i8>>
// Store it into tuple element 1.
// CHECK-NEXT: %[[TUPLE:.*]] = llvm.bitcast %[[PTR0]] : !llvm.ptr<i8> to !llvm.ptr<ptr<i8>>
// CHECK-NEXT: %[[C1:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK-NEXT: %[[TUPLE_ELEMENT:.*]] = llvm.getelementptr %[[TUPLE]][%[[C1]]] : (!llvm.ptr<ptr<i8>>, i32) -> !llvm.ptr<ptr<i8>>
// CHECK-NEXT: llvm.store %[[PTR2]], %[[TUPLE_ELEMENT]] : !llvm.ptr<ptr<i8>>
// No return values.
// CHECK-NEXT: return