Home
last modified time | relevance | path

Searched defs:num_head (Results 1 – 8 of 8) sorted by relevance

/third_party/mindspore/mindspore-src/source/tests/st/mindscience/mindsponge/mindsponge/cell/
Dmsa.py32 …def __init__(self, num_head, key_dim, gating, msa_act_dim, pair_act_dim, batch_size=None, slice_nu… argument
115 def __init__(self, num_head, key_dim, gating, msa_act_dim, batch_size=None, slice_num=0): argument
167 def __init__(self, num_head, gating, msa_act_dim, batch_size=None, slice_num=0): argument
Dbasic.py33 def __init__(self, num_head, hidden_size, gating, q_data_dim, m_data_dim, output_dim, argument
180 def __init__(self, num_head, gating, input_dim, output_dim, batch_size=None): argument
Dequivariant.py34 …def __init__(self, num_head, num_scalar_qk, num_scalar_v, num_point_v, num_point_qk, num_channel, … argument
Dtriangle.py35 …def __init__(self, orientation, num_head, key_dim, gating, layer_norm_dim, batch_size=None, slice_… argument
/third_party/mindspore/mindspore-src/source/tests/ut/cpp/ops/
Dtest_ops_paged_attention.cc42 ValuePtr num_head; member
59 auto num_head = param.num_head->ToAbstract(); in TEST_P() local
Dtest_ops_paged_attention_mask.cc44 ValuePtr num_head; member
62 auto num_head = param.num_head->ToAbstract(); in TEST_P() local
/third_party/mindspore/mindspore-src/source/tests/ut/python/parallel/
Dtest_reshape_and_cache.py27 def generate_inputs(bs=1, seq_len=1, num_head=40, head_dim=128, block_size=16, max_seq=2048): argument
/third_party/mindspore/mindspore-src/source/mindspore/lite/tools/optimizer/fusion/
Dflash_attention_fusion.cc308 …mForIpAdapterPattern(const CNodePtr &q_trans_BNSD, const CNodePtr &k_trans_BNDS, int64_t *num_head, in GetParamForIpAdapterPattern()
1639 … const AnfNodePtr &v_trans, int64_t num_head, int64_t next_token, in CreateFAForSD15()
1670 … const AnfNodePtr &v_trans, const AnfNodePtr &pse, int64_t num_head, in CreateFAWithPadAndPse()
1787 int64_t num_head = input_tensor_q_shape[kNumIndex1]; in CreateFlashAttentionNodeForMsSDXL() local
1868 int64_t num_head = input_tensor_q_shape[kNumIndex1]; in CreateFlashAttentionNodeForMsSDPseShift() local
1970 int64_t num_head = input_tensor_q_shape[kNumIndex1]; in CreateFlashAttentionNodeForMsSD21() local
2032 int64_t num_head = 0; in CreateFlashAttentionNodeForVideoComposer() local
2123 int64_t num_head = 0; in CreateFlashAttentionNodeForSD() local
2201 int64_t num_head = 0; in CreateFlashAttentionNodeForSDPreMul() local
2318 int64_t num_head = 0; in CreateFlashAttentionNodeForSDWithoutCast() local
[all …]