Home
last modified time | relevance | path

Searched refs:key_shape (Results 1 – 4 of 4) sorted by relevance

/third_party/mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/
pull_kernel.h
54 auto key_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); in Init() local
55 for (size_t i = 0; i < key_shape.size(); i++) { in Init()
56 keys_size_ *= key_shape[i]; in Init()
/third_party/mindspore/mindspore/ccsrc/minddata/dataset/api/
datasets.cc
758 auto key_shape = column.find("shape"); in parse_column() local
759 if (key_shape != column.end()) { in parse_column()
760 shape.insert(shape.end(), (*key_shape).begin(), (*key_shape).end()); in parse_column()
774 auto key_shape = it_child.value().find("shape"); in parse_column() local
775 if (key_shape != it_child.value().end()) { in parse_column()
776 shape.insert(shape.end(), (*key_shape).begin(), (*key_shape).end()); in parse_column()
/third_party/mindspore/mindspore/parallel/nn/
transformer.py
942 key_shape = F.shape(key_tensor)
943 key_tensor = F.reshape(key_tensor, (-1, key_shape[-1]))
1230 self.key_shape = (batch_size, num_heads, size_per_head, seq_length)
1233 … self.key_past = Parameter(Tensor(np.zeros(shape=self.key_shape), self.dtype), name="key_past")
1548 self.key_shape = (batch_size, num_heads, size_per_head, tgt_seq_length)
1551 … self.key_past = Parameter(Tensor(np.zeros(shape=self.key_shape), self.dtype), name="key_past")
/third_party/mindspore/mindspore/ops/operations/
other_ops.py
723 def infer_shape(self, key_shape, weight_shape): argument