• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2019-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 #ifndef MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_OPERATOR_H_
17 #define MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_OPERATOR_H_
18 
19 #include <memory>
20 #include <vector>
21 #include "minddata/mindrecord/include/shard_task_list.h"
22 #include "minddata/dataset/include/dataset/constants.h"
23 
24 namespace mindspore {
25 namespace mindrecord {
26 class MINDRECORD_API ShardOperator {
27  public:
28   virtual ~ShardOperator() = default;
29 
operator()30   Status operator()(ShardTaskList &tasks) {  // NOLINT
31     RETURN_IF_NOT_OK_MR(this->PreExecute(tasks));
32     RETURN_IF_NOT_OK_MR(this->Execute(tasks));
33     RETURN_IF_NOT_OK_MR(this->SufExecute(tasks));
34     return Status::OK();
35   }
36 
HasChildOp()37   virtual bool HasChildOp() { return child_op_ != nullptr; }
38 
SetChildOp(const std::shared_ptr<ShardOperator> & child_op)39   virtual Status SetChildOp(const std::shared_ptr<ShardOperator> &child_op) {
40     if (child_op != nullptr) {
41       child_op_ = child_op;
42     }
43     return Status::OK();
44   }
45 
GetChildOp()46   virtual std::shared_ptr<ShardOperator> GetChildOp() { return child_op_; }
47 
PreExecute(ShardTaskList & tasks)48   virtual Status PreExecute(ShardTaskList &tasks) { return Status::OK(); }  // NOLINT
49 
50   virtual Status Execute(ShardTaskList &tasks) = 0;  // NOLINT
51 
SufExecute(ShardTaskList & tasks)52   virtual Status SufExecute(ShardTaskList &tasks) { return Status::OK(); }  // NOLINT
53 
54   /// \brief compute actual the num_samples via loading data
GetNumSamples(int64_t dataset_size,int64_t num_classes)55   virtual int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) { return 0; }
56 
57   /// \brief Getter the number of samples which is set via python api
GetNumSamples()58   virtual int64_t GetNumSamples() const { return num_samples_; }
59 
60   /// \brief Setter the number of samples in python
SetNumSamples(int64_t num_samples)61   virtual void SetNumSamples(int64_t num_samples) { num_samples_ = num_samples; }
62 
UpdateShuffleMode(dataset::ShuffleMode shuffle_mode)63   virtual void UpdateShuffleMode(dataset::ShuffleMode shuffle_mode) { shuffle_mode_ = shuffle_mode; }
64 
GetShuffleMode()65   virtual dataset::ShuffleMode GetShuffleMode() { return shuffle_mode_; }
66 
SetShardSampleCount(const std::vector<int64_t> & shard_sample_count)67   virtual void SetShardSampleCount(const std::vector<int64_t> &shard_sample_count) {
68     shard_sample_count_ = shard_sample_count;
69   }
70 
GetShardSampleCount()71   virtual std::vector<int64_t> GetShardSampleCount() { return shard_sample_count_; }
72 
73  private:
74   int64_t num_samples_ = 0;
75   std::shared_ptr<ShardOperator> child_op_ = nullptr;
76   // indicate shard_id : inc_count
77   // 0 : 15  -  shard0 has 15 samples
78   // 1 : 41  -  shard1 has 26 samples
79   // 2 : 58  -  shard2 has 17 samples
80   std::vector<int64_t> shard_sample_count_;
81   dataset::ShuffleMode shuffle_mode_ = dataset::ShuffleMode::kGlobal;
82 };
83 }  // namespace mindrecord
84 }  // namespace mindspore
85 #endif  // MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_OPERATOR_H_
86