Searched refs:axis_val (Results 1 – 3 of 3) sorted by relevance
/third_party/mindspore/mindspore/core/ops/ |
D | gather.cc | 49 int64_t axis_val = 0; in GatherInfer() local 58 axis_val = *static_cast<int64_t *>(axis_tensor->data_c()); in GatherInfer() 61 axis_val = GetValue<int64_t>(axis->BuildValue()); in GatherInfer() 68 …CheckAndConvertUtils::CheckInRange<int64_t>("axis", axis_val, kIncludeLeft, {-params_rank, params_… in GatherInfer() 75 if (!(-params_rank <= axis_val) || !(axis_val < params_rank)) { in GatherInfer() 77 << "Got " << axis_val << "."; in GatherInfer() 79 if (axis_val < 0) { in GatherInfer() 80 axis_val += params_rank; in GatherInfer() 82 …auto calc_shape = [axis_val](const ShapeVector &ind_vec, const ShapeVector &params_vec) -> ShapeVe… in GatherInfer() 84 (void)std::copy(params_vec.begin(), params_vec.begin() + axis_val, std::back_inserter(out_vec)); in GatherInfer() [all …]
|
/third_party/mindspore/mindspore/lite/src/delegate/tensorrt/op/ |
D | softmax_tensorrt.cc | 71 auto axis_val = std::vector<int64_t>(axis->begin(), axis->end()); in AddSoftMaxOp() local 73 if (axis_val.size() != 1) { in AddSoftMaxOp() 77 if (axis_val[0] >= this->tensorrt_in_tensors_[0].trt_tensor_->getDimensions().nbDims) { in AddSoftMaxOp() 81 int64_t axis_format_value = axis_val[0]; in AddSoftMaxOp() 85 axis_format_value = ConvertAxisFromNHWC2NCHW(axis_val[0]); in AddSoftMaxOp()
|
/third_party/mindspore/mindspore/core/abstract/ |
D | prim_arrays.cc | 588 int64_t axis_val = 0; in InferImplGatherV2() local 598 axis_val = *static_cast<int64_t *>(axis_tensor->data_c()); in InferImplGatherV2() 601 axis_val = GetValue<int64_t>(axis->BuildValue()); in InferImplGatherV2() 614 if (-params_rank > axis_val || axis_val >= params_rank) { in InferImplGatherV2() 616 << "Got " << axis_val << "."; in InferImplGatherV2() 618 if (axis_val < 0) { in InferImplGatherV2() 619 axis_val += params_rank; in InferImplGatherV2() 621 …auto calc_shape = [axis_val](const ShapeVector &ind_vec, const ShapeVector &params_vec) -> ShapeVe… in InferImplGatherV2() 623 std::copy(params_vec.begin(), params_vec.begin() + axis_val, std::back_inserter(out_vec)); in InferImplGatherV2() 625 copy(params_vec.begin() + axis_val + 1, params_vec.end(), std::back_inserter(out_vec)); in InferImplGatherV2()
|