• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
#include "minddata/dataset/core/tensor_helpers.h"

#include <limits>

#include "minddata/dataset/util/log_adapter.h"
#include "minddata/dataset/util/status.h"
19 
20 namespace mindspore {
21 namespace dataset {
22 
IndexGeneratorHelper(int8_t depth,std::vector<dsize_t> * numbers,const std::vector<mindspore::dataset::SliceOption> & slice_list,std::vector<std::vector<dsize_t>> * matrix)23 void IndexGeneratorHelper(int8_t depth, std::vector<dsize_t> *numbers,
24                           const std::vector<mindspore::dataset::SliceOption> &slice_list,
25                           std::vector<std::vector<dsize_t>> *matrix) {
26   if (numbers == nullptr || matrix == nullptr) {
27     MS_LOG(ERROR) << "Invalid input pointer, can't be NULL";
28     return;
29   }
30   // for loop changes if its an index instead of a slice object
31   if (depth > 0) {
32     int8_t new_depth = depth - 1;
33     // depth is always less than or equal to numbers->size() (based on the caller functions)
34     size_t curr_ind = static_cast<size_t>(numbers->size() - static_cast<size_t>(depth));
35     if (curr_ind >= slice_list.size()) {
36       MS_LOG(ERROR) << "The index is out of range in slice_list.";
37       return;
38     }
39 
40     if (slice_list[curr_ind].slice_.valid()) {
41       dsize_t increment = slice_list[curr_ind].slice_.step_;
42 
43       if (increment > 0) {
44         for (dsize_t i = slice_list[curr_ind].slice_.start_; i < slice_list[curr_ind].slice_.stop_;
45              i = i + slice_list[curr_ind].slice_.step_) {
46           (*numbers)[curr_ind] = i;
47           IndexGeneratorHelper(new_depth, numbers, slice_list, matrix);
48         }
49       } else {
50         for (dsize_t j = slice_list[curr_ind].slice_.start_; j > slice_list[curr_ind].slice_.stop_;
51              j = j + slice_list[curr_ind].slice_.step_) {
52           (*numbers)[curr_ind] = j;
53           IndexGeneratorHelper(new_depth, numbers, slice_list, matrix);
54         }
55       }
56     } else {
57       for (size_t k = 0; k < slice_list[curr_ind].indices_.size(); k++) {
58         (*numbers)[curr_ind] = slice_list[curr_ind].indices_[k];
59         IndexGeneratorHelper(new_depth, numbers, slice_list, matrix);
60       }
61     }
62 
63   } else {
64     (*matrix).emplace_back((*numbers));
65   }
66 }
67 
68 // Used to generate slice indices
IndexGenerator(const std::vector<mindspore::dataset::SliceOption> & slice_list)69 std::vector<std::vector<dsize_t>> IndexGenerator(const std::vector<mindspore::dataset::SliceOption> &slice_list) {
70   int8_t depth = slice_list.size();
71   std::vector<dsize_t> numbers(depth, 0);
72   std::vector<std::vector<dsize_t>> matrix(0, std::vector<dsize_t>(depth, 0));
73 
74   IndexGeneratorHelper(depth, &numbers, slice_list, &matrix);
75 
76   return matrix;
77 }
78 }  // namespace dataset
79 }  // namespace mindspore
80