Home
last modified time | relevance | path

Searched refs:axis_val (Results 1 – 3 of 3) sorted by relevance

/third_party/mindspore/mindspore/core/ops/
gather.cc
49 int64_t axis_val = 0; in GatherInfer() local
58 axis_val = *static_cast<int64_t *>(axis_tensor->data_c()); in GatherInfer()
61 axis_val = GetValue<int64_t>(axis->BuildValue()); in GatherInfer()
68 …CheckAndConvertUtils::CheckInRange<int64_t>("axis", axis_val, kIncludeLeft, {-params_rank, params_… in GatherInfer()
75 if (!(-params_rank <= axis_val) || !(axis_val < params_rank)) { in GatherInfer()
77 << "Got " << axis_val << "."; in GatherInfer()
79 if (axis_val < 0) { in GatherInfer()
80 axis_val += params_rank; in GatherInfer()
82 …auto calc_shape = [axis_val](const ShapeVector &ind_vec, const ShapeVector &params_vec) -> ShapeVe… in GatherInfer()
84 (void)std::copy(params_vec.begin(), params_vec.begin() + axis_val, std::back_inserter(out_vec)); in GatherInfer()
[all …]
/third_party/mindspore/mindspore/lite/src/delegate/tensorrt/op/
softmax_tensorrt.cc
71 auto axis_val = std::vector<int64_t>(axis->begin(), axis->end()); in AddSoftMaxOp() local
73 if (axis_val.size() != 1) { in AddSoftMaxOp()
77 if (axis_val[0] >= this->tensorrt_in_tensors_[0].trt_tensor_->getDimensions().nbDims) { in AddSoftMaxOp()
81 int64_t axis_format_value = axis_val[0]; in AddSoftMaxOp()
85 axis_format_value = ConvertAxisFromNHWC2NCHW(axis_val[0]); in AddSoftMaxOp()
/third_party/mindspore/mindspore/core/abstract/
prim_arrays.cc
588 int64_t axis_val = 0; in InferImplGatherV2() local
598 axis_val = *static_cast<int64_t *>(axis_tensor->data_c()); in InferImplGatherV2()
601 axis_val = GetValue<int64_t>(axis->BuildValue()); in InferImplGatherV2()
614 if (-params_rank > axis_val || axis_val >= params_rank) { in InferImplGatherV2()
616 << "Got " << axis_val << "."; in InferImplGatherV2()
618 if (axis_val < 0) { in InferImplGatherV2()
619 axis_val += params_rank; in InferImplGatherV2()
621 …auto calc_shape = [axis_val](const ShapeVector &ind_vec, const ShapeVector &params_vec) -> ShapeVe… in InferImplGatherV2()
623 std::copy(params_vec.begin(), params_vec.begin() + axis_val, std::back_inserter(out_vec)); in InferImplGatherV2()
625 copy(params_vec.begin() + axis_val + 1, params_vec.end(), std::back_inserter(out_vec)); in InferImplGatherV2()