Searched refs:segment_ids_shape (Results 1 – 6 of 6) sorted by relevance
/third_party/mindspore/mindspore/core/ops/ |
D | unsorted_segment_sum.cc | 42 …auto segment_ids_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape… in UnsortedSegmentSumInfer() local 43 …(void)CheckAndConvertUtils::CheckInteger("segment_ids_shape", SizeToLong(segment_ids_shape.size())… in UnsortedSegmentSumInfer() 46 int64_t(segment_ids_shape.size()), prim_name); in UnsortedSegmentSumInfer() 49 (segment_ids_shape.end() != find(segment_ids_shape.begin(), segment_ids_shape.end(), -1))) { in UnsortedSegmentSumInfer() 50 size_t size = segment_ids_shape.size(); in UnsortedSegmentSumInfer() 52 …CheckAndConvertUtils::Check("segment_ids_shp", segment_ids_shape[i], kEqual, "x_shape", x_shape[i]… in UnsortedSegmentSumInfer() 59 size_t size_segment_ids_shp = segment_ids_shape.size(); in UnsortedSegmentSumInfer()
|
/third_party/mindspore/mindspore/core/abstract/ |
D | prim_arrays.cc | 237 auto segment_ids_shape = segment_ids->shape()->shape(); in InferImplUnsortedSegmentSum() local 250 shape.insert(shape.end(), x_shape.begin() + segment_ids_shape.size(), x_shape.end()); in InferImplUnsortedSegmentSum() 252 for (size_t i = 0; i < segment_ids_shape.size(); i++) { in InferImplUnsortedSegmentSum() 253 if (x_shape[i] != segment_ids_shape[i]) { in InferImplUnsortedSegmentSum() 265 …std::any_of(segment_ids_shape.begin(), segment_ids_shape.end(), [](int64_t dim) { return dim == Sh… in InferImplUnsortedSegmentSum() 267 for (size_t i = 0; i < segment_ids_shape.size(); i++) { in InferImplUnsortedSegmentSum() 268 if (x_shape[i] != segment_ids_shape[i]) { in InferImplUnsortedSegmentSum() 277 …min_shape.insert(min_shape.end(), x_shape_min.begin() + segment_ids_shape.size(), x_shape_min.end(… in InferImplUnsortedSegmentSum() 278 …max_shape.insert(max_shape.end(), x_shape_max.begin() + segment_ids_shape.size(), x_shape_max.end(… in InferImplUnsortedSegmentSum() 292 auto segment_ids_shape = segment_ids->shape()->shape(); in InferImplUnsortedSegmentMax() local [all …]
|
/third_party/mindspore/mindspore/lite/src/runtime/kernel/arm/fp16_grad/ |
D | unsorted_segment_sum_fp16.cc | 42 auto segment_ids_shape = in_tensors_.at(1)->shape(); in Init() local 48 if (i >= segment_ids_shape.size()) { in Init()
|
/third_party/mindspore/mindspore/lite/src/runtime/kernel/arm/fp32_grad/ |
D | unsorted_segment_sum.cc | 42 auto segment_ids_shape = in_tensors_.at(1)->shape(); in Init() local 48 if (i >= segment_ids_shape.size()) { in Init()
|
/third_party/mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/ |
D | unsorted_segment_sum_cpu_kernel.cc | 34 auto segment_ids_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); in InitKernel() local 41 if (i >= segment_ids_shape.size()) { in InitKernel()
|
/third_party/mindspore/mindspore/ops/operations/ |
D | array_ops.py | 2158 segment_ids_shape = segment_ids['shape'] 2162 validator.check_equal_int(len(segment_ids_shape), 1, "rank of segment_ids_shape", self.name) 2165 if -1 not in x_shape and -1 not in segment_ids_shape: 2168 'length of segments_id', segment_ids_shape[0], Rel.EQ, self.name) 2284 segment_ids_shape = segment_ids['shape'] 2289 validator.check_equal_int(len(segment_ids_shape), 1, "rank of segment_ids_shape", self.name) 2292 if -1 not in x_shape and -1 not in segment_ids_shape: 2295 'length of segments_id', segment_ids_shape[0], Rel.EQ, self.name) 2346 segment_ids_shape = segment_ids['shape'] 2352 validator.check_equal_int(len(segment_ids_shape), 1, "rank of segment_ids_shape", self.name) [all …]
|