Home
last modified time | relevance | path

Searched refs:concat_dim (Results 1 – 15 of 15) sorted by relevance

/third_party/mindspore/tests/ut/python/parallel/
test_alltoall.py:132 self.alltoallv = AlltoAll(split_count=8, split_dim=2, concat_dim=3)
153 self.alltoallv = AlltoAll(split_count=7, split_dim=2, concat_dim=3)
175 self.alltoallv = AlltoAll(split_count=[8], split_dim=2, concat_dim=3)
197 self.alltoallv = AlltoAll(split_count=8, split_dim=4, concat_dim=3)
219 self.alltoallv = AlltoAll(split_count=8, split_dim=(3,), concat_dim=3)
241 self.alltoallv = AlltoAll(split_count=8, split_dim=3, concat_dim=4)
263 self.alltoallv = AlltoAll(split_count=8, split_dim=3, concat_dim=([3],))
285 self.alltoallv = AlltoAll(split_count=3, split_dim=3, concat_dim=3)
307 self.alltoallv = AlltoAll(split_count=8, split_dim=3, concat_dim=3, group=3)
/third_party/mindspore/mindspore/ccsrc/frontend/parallel/tensor_layout/
construct_operator.cc:164 Status ConstructOperator::ConcatOP(int64_t concat_dim) { in ConcatOP() argument
165 if (LongToSize(concat_dim) >= tensor_shape_.size() || concat_dim < 0) { in ConcatOP()
166 … MS_LOG(ERROR) << "Invalid tensor dimension " << concat_dim << " when construct Concat operator!"; in ConcatOP()
169 ValuePtr attr_value = MakeValue(concat_dim); in ConcatOP()
203 int64_t concat_dim = args[TRANSFER_PERMUTE_CONCAT_DIM_INDEX]; in AlltoAllOP() local
214 if (LongToSize(concat_dim) >= tensor_shape_.size() || concat_dim < 0) { in AlltoAllOP()
239 ValuePtr attr_value_concat_dim = MakeValue(concat_dim); in AlltoAllOP()
tensor_redistribution.cc:252 int64_t concat_dim = attrs[TRANSFER_PERMUTE_CONCAT_DIM_INDEX]; in ComputePermuteCost() local
253 if (concat_dim == 0) { in ComputePermuteCost()
278 int64_t concat_dim = attrs[TRANSFER_CONCAT_TENSOR_DIM_INDEX]; in ComputeConcatCost() local
279 if (concat_dim == 0) { in ComputeConcatCost()
construct_operator.h:43 Status ConcatOP(int64_t concat_dim);
/third_party/mindspore/mindspore/ccsrc/backend/optimizer/ascend/mindir/
all_to_all_unify_mindir.cc:131 int64_t concat_dim = AnfAlgo::GetNodeAttr<int64_t>(all_to_all, kAttrConcatDim); in CreateConcatNode() local
142 concat_dim = NormalizeDim(single_shape, concat_dim); in CreateConcatNode()
143 if (LongToSize(concat_dim) >= single_shape.size()) { in CreateConcatNode()
144 …MS_LOG(EXCEPTION) << "Invalid concat dim " << concat_dim << " is greater than shape size " << sing… in CreateConcatNode()
146 single_shape[LongToSize(concat_dim)] *= static_cast<size_t>(split_count); in CreateConcatNode()
149 AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue<int64_t>(concat_dim), concat); in CreateConcatNode()
/third_party/mindspore/tests/ut/cpp/parallel/tensor_layout/
construct_operator_test.cc:120 int64_t concat_dim = 0; in TEST_F() local
121 ASSERT_EQ(constructor.ConcatOP(concat_dim), Status::SUCCESS); in TEST_F()
132 int64_t concat_dim = 1; in TEST_F() local
135 Args args = {split_count, split_dim, concat_dim, dev_dim, dev_num}; in TEST_F()
/third_party/mindspore/mindspore/lite/tools/converter/parser/caffe/
caffe_concat_parser.cc:36 MS_LOG(DEBUG) << "Concat dim , set axis: " << concatParam.concat_dim(); in Parse()
37 axis = concatParam.concat_dim(); in Parse()
/third_party/mindspore/mindspore/ops/operations/
comm_ops.py:681 def __init__(self, split_count, split_dim, concat_dim, group=GlobalComm.WORLD_COMM_GROUP): argument
686 validator.check_is_int(concat_dim, int)
689 self.concat_dim = concat_dim
701 x_shape[self.concat_dim] = x_shape[self.concat_dim] * self.split_count
/third_party/mindspore/mindspore/lite/tools/optimizer/parallel/
conv2d_info.cc:373 int32_t concat_dim; in InferReplaceOp() local
375 concat_dim = kAxisN; in InferReplaceOp()
378 concat_dim = kAxisCOut; in InferReplaceOp()
380 concat_dim = kAxisH; in InferReplaceOp()
382 replace_op_ = CreateConcateNode(cnode_, parallel_output_nodes_, concat_dim, dev_num); in InferReplaceOp()
operator_info.cc:138 int32_t concat_dim, size_t input_nodes_num) { in CreateConcateNode() argument
146 concat_prim->set_axis(concat_dim); in CreateConcateNode()
operator_info.h:68 int32_t concat_dim, size_t input_nodes_num);
/third_party/mindspore/tests/ut/cpp/python_input/gtest_input/pre_activate/
all_to_all_unify_mindir_test.py:48 altoall = AlltoAll(split_count=2, split_dim=2, concat_dim=3)
/third_party/mindspore/mindspore/ccsrc/transform/graph_ir/op_declare/
split_combination_ops_declare.cc:49 {"axis", ATTR_DESC(concat_dim, AnyTraits<int64_t>())},
/third_party/mindspore/mindspore/ops/_grad/
grad_comm_ops.py:380 all_to_all_grad = AlltoAll(self.split_count, self.concat_dim, self.split_dim, self.group)
/third_party/mindspore/third_party/proto/caffe/
caffe.proto:520 optional uint32 concat_dim = 1 [default = 1]; field
1446 optional uint32 concat_dim = 65 [default = 1]; field