Home
last modified time | relevance | path

Searched refs:x2_shape (Results 1 – 7 of 7) sorted by relevance

/third_party/mindspore/mindspore/ops/composite/
Dmath_ops.py155 def _axes_int_check(x1_shape, x2_shape, axes, prim_name=None): argument
166 if axes > len(x1_shape) or axes > len(x2_shape):
170 x2_ind = tuple(range(len(x2_shape))[:axes])
177 def _validate_axes(x1_shape, x2_shape, axes, prim_name=None): argument
184 shapes = [x1_shape, x2_shape]
208 if x1_shape[axes[0][i]] != x2_shape[axes[1][i]]:
210 if x1_shape[axes[0][i]] != x2_shape[axes[1][len(axes[0])-1-i]]:
284 x2_shape = shape_op(x2)
290 axes = _axes_int_check(x1_shape, x2_shape, axes, prim_name)
291 _validate_axes(x1_shape, x2_shape, axes, prim_name)
[all …]
/third_party/mindspore/mindspore/nn/layer/
Dmath.py787 def check_col_row_equal(x1_shape, x2_shape, transpose_x1, transpose_x2, prim_name=None): argument
793 if len(x2_shape) == 1:
795 x2_shape = x2_shape + (1,)
797 x2_last = x2_shape[-2:]
805 def matmul_op_select(x1_shape, x2_shape, transpose_x1, transpose_x2): argument
807 x1_dim, x2_dim = len(x1_shape), len(x2_shape)
886 x2_shape = self.shape_op(x2)
887 check_col_row_equal(x1_shape, x2_shape, self.transpose_x1, self.transpose_x2, self.cls_name)
888 matmul_op = matmul_op_select(x1_shape, x2_shape, self.transpose_x1, self.transpose_x2)
890 x1_dim, x2_dim = len(x1_shape), len(x2_shape)
[all …]
/third_party/mindspore/mindspore/ops/_op_impl/_custom_op/
Dbatch_matmul_impl.py198 x2_shape = input_x2.get("shape")
202 input_shape = (tuple(x1_shape), tuple(x2_shape), dtype, transpose_a, transpose_b)
210 input2_shape = _get_flattern_shape(x2_shape)
/third_party/mindspore/mindspore/ops/operations/
D_thor_ops.py695 def infer_shape(self, x1_shape, x2_shape, x3_shape): argument
696 return x2_shape
796 def infer_shape(self, x1_shape, x2_shape, x3_shape): argument
D_grad_ops.py830 def infer_shape(self, x1_shape, x2_shape, grad_shape): argument
882 def infer_shape(self, x1_shape, x2_shape, grad_shape): argument
925 def infer_shape(self, x1_shape, x2_shape, grad_shape): argument
926 return x2_shape
Dnn_ops.py8018 def infer_shape(self, x1_shape, x2_shape): argument
8020 validator.check("x2 shape", len(x2_shape), "", 1, Rel.EQ, self.name)
8021 … validator.check("size of x2", x2_shape[0], "x1's first dimension", x1_shape[0], Rel.EQ, self.name)
8022 return x2_shape
/third_party/mindspore/mindspore/ccsrc/backend/optimizer/trt_pass/
Dtrt_op_converter.cc92 const std::vector<size_t> &x2_shape = AnfAlgo::GetPrevNodeOutputInferShape(node, 1); in AddElementLayer() local
121 auto *x2 = Broadcast(ToTensor(&inputs[1], x2_shape, context), x2_shape); in AddElementLayer()
557 const std::vector<size_t> &x2_shape = AnfAlgo::GetPrevNodeOutputInferShape(node, 1); in MS_TRT_CONVERTER_FUNC_REG() local
559 nvinfer1::ITensor *x2 = ToTensor(&inputs[1], x2_shape, context); in MS_TRT_CONVERTER_FUNC_REG()