1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
#include "minddata/dataset/core/tensor_helpers.h"
#include <limits>
#include <memory>

#include "include/dataset/constants.h"
#include "include/dataset/transforms.h"
#include "minddata/dataset/util/log_adapter.h"
22 namespace mindspore {
23 namespace dataset {
24
/**
 * Recursively enumerates the Cartesian product of the per-dimension indices described by slice_list.
 *
 * `numbers` is a scratch index tuple; this call fills position (numbers->size() - depth) and then
 * recurses with one less dimension remaining. When no dimensions remain, the completed tuple is
 * appended to `matrix`. The last dimension therefore varies fastest in the output.
 *
 * Errors (null pointers, or a dimension index outside slice_list) are logged and the call returns
 * without appending anything further; tuples appended before the error are kept.
 *
 * @param depth Number of dimensions still to fill, counted from the end of `numbers`.
 * @param numbers Scratch index tuple, overwritten in place. Must not be null.
 * @param slice_list Per-dimension options: a valid `slice_` (start/stop/step walk) or, when the
 *        slice is not valid, the explicit list in `indices_`.
 * @param matrix Output; each generated index tuple is appended. Must not be null.
 */
void IndexGeneratorHelper(int8_t depth, std::vector<dsize_t> *numbers,
                          const std::vector<mindspore::dataset::SliceOption> &slice_list,
                          std::vector<std::vector<dsize_t>> *matrix) {
  if (numbers == nullptr || matrix == nullptr) {
    MS_LOG(ERROR) << "Invalid input pointer, can't be NULL";
    return;
  }

  // Base case: every dimension has been assigned, so record the finished tuple.
  if (depth <= 0) {
    matrix->emplace_back(*numbers);
    return;
  }

  const int8_t remaining = depth - 1;
  // Callers pass depth <= numbers->size(); a larger depth wraps around to a huge value and is
  // rejected by the range check below.
  const size_t dim = numbers->size() - static_cast<size_t>(depth);
  if (dim >= slice_list.size()) {
    MS_LOG(ERROR) << "The index is out of range in slice_list.";
    return;
  }
  const auto &option = slice_list[dim];

  // Not a valid slice: visit the explicitly listed indices in order.
  if (!option.slice_.valid()) {
    for (const auto index : option.indices_) {
      (*numbers)[dim] = index;
      IndexGeneratorHelper(remaining, numbers, slice_list, matrix);
    }
    return;
  }

  // Valid slice: walk from start_ towards stop_ (exclusive) by step_, upward for a positive step
  // and downward otherwise.
  const auto &slice = option.slice_;
  const bool ascending = slice.step_ > 0;
  const auto before_stop = [&slice, ascending](dsize_t value) {
    return ascending ? value < slice.stop_ : value > slice.stop_;
  };
  for (dsize_t value = slice.start_; before_stop(value); value += slice.step_) {
    (*numbers)[dim] = value;
    IndexGeneratorHelper(remaining, numbers, slice_list, matrix);
  }
}
69
70 // Used to generate slice indices
IndexGenerator(const std::vector<mindspore::dataset::SliceOption> & slice_list)71 std::vector<std::vector<dsize_t>> IndexGenerator(const std::vector<mindspore::dataset::SliceOption> &slice_list) {
72 int8_t depth = slice_list.size();
73 std::vector<dsize_t> numbers(depth, 0);
74 std::vector<std::vector<dsize_t>> matrix(0, std::vector<dsize_t>(depth, 0));
75
76 IndexGeneratorHelper(depth, &numbers, slice_list, &matrix);
77
78 return matrix;
79 }
80 } // namespace dataset
81 } // namespace mindspore
82