import numpy as np
from mindspore import Tensor
from mindspore.parallel.nn.layers import FixedSparseAttention
import mindspore.context as context

context.set_context(device_target="Ascend")


def test_net():
    np.random.seed(0)
    bs = 2  # batch size
    heads = 2
    seq_len = 1024  # this op is designed for seq_len = 1024
    size_per_head = 128  # maximum size per head value is 128

    block_size = 64  # block size is designed to be 64
    fixed_sparse = FixedSparseAttention(bs, heads, size_per_head, block_size)
    q = np.random.rand(bs, seq_len, heads * size_per_head)
    q = q.astype(np.float16)
    k = np.random.rand(bs, seq_len, heads * size_per_head)
    k = k.astype(np.float16)
    v = np.random.rand(bs, seq_len, heads * size_per_head)
    v = v.astype(np.float16)
    attention_mask = np.ones((bs, seq_len, seq_len), dtype=np.float32)
    out = fixed_sparse(Tensor(q), Tensor(k), Tensor(v), Tensor(attention_mask))
    out_np = out.asnumpy()
    print("local output: ", out_np[0, 0])
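
# Optional entry point (an assumption, not part of the original script) so the
# example can be run directly with `python` on a machine that has an Ascend
# device available; test runners such as pytest would call test_net() themselves.
if __name__ == "__main__":
    test_net()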