Searched refs:key_shape (Results 1 – 4 of 4) sorted by relevance
/third_party/mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/ |
D | pull_kernel.h | 54 auto key_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); in Init() local 55 for (size_t i = 0; i < key_shape.size(); i++) { in Init() 56 keys_size_ *= key_shape[i]; in Init()
|
/third_party/mindspore/mindspore/ccsrc/minddata/dataset/api/ |
D | datasets.cc | 758 auto key_shape = column.find("shape"); in parse_column() local 759 if (key_shape != column.end()) { in parse_column() 760 shape.insert(shape.end(), (*key_shape).begin(), (*key_shape).end()); in parse_column() 774 auto key_shape = it_child.value().find("shape"); in parse_column() local 775 if (key_shape != it_child.value().end()) { in parse_column() 776 shape.insert(shape.end(), (*key_shape).begin(), (*key_shape).end()); in parse_column()
|
/third_party/mindspore/mindspore/parallel/nn/ |
D | transformer.py | 942 key_shape = F.shape(key_tensor) 943 key_tensor = F.reshape(key_tensor, (-1, key_shape[-1])) 1230 self.key_shape = (batch_size, num_heads, size_per_head, seq_length) 1233 … self.key_past = Parameter(Tensor(np.zeros(shape=self.key_shape), self.dtype), name="key_past") 1548 self.key_shape = (batch_size, num_heads, size_per_head, tgt_seq_length) 1551 … self.key_past = Parameter(Tensor(np.zeros(shape=self.key_shape), self.dtype), name="key_past")
|
/third_party/mindspore/mindspore/ops/operations/ |
D | other_ops.py | 723 def infer_shape(self, key_shape, weight_shape): argument
|