• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_OPERATOR_H_
17 #define MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_OPERATOR_H_
18 
19 #include <memory>
20 #include <vector>
21 #include "minddata/mindrecord/include/shard_task_list.h"
22 #include "minddata/dataset/include/dataset/constants.h"
23 
24 namespace mindspore {
25 namespace mindrecord {
26 class __attribute__((visibility("default"))) ShardOperator {
27  public:
28   virtual ~ShardOperator() = default;
29 
operator()30   Status operator()(ShardTaskList &tasks) {
31     RETURN_IF_NOT_OK(this->PreExecute(tasks));
32     RETURN_IF_NOT_OK(this->Execute(tasks));
33     RETURN_IF_NOT_OK(this->SufExecute(tasks));
34     return Status::OK();
35   }
36 
HasChildOp()37   virtual bool HasChildOp() { return child_op_ != nullptr; }
38 
SetChildOp(std::shared_ptr<ShardOperator> child_op)39   virtual Status SetChildOp(std::shared_ptr<ShardOperator> child_op) {
40     if (child_op != nullptr) {
41       child_op_ = child_op;
42     }
43     return Status::OK();
44   }
45 
GetChildOp()46   virtual std::shared_ptr<ShardOperator> GetChildOp() { return child_op_; }
47 
PreExecute(ShardTaskList & tasks)48   virtual Status PreExecute(ShardTaskList &tasks) { return Status::OK(); }
49 
50   virtual Status Execute(ShardTaskList &tasks) = 0;
51 
SufExecute(ShardTaskList & tasks)52   virtual Status SufExecute(ShardTaskList &tasks) { return Status::OK(); }
53 
GetNumSamples(int64_t dataset_size,int64_t num_classes)54   virtual int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) { return 0; }
55 
UpdateShuffleMode(dataset::ShuffleMode shuffle_mode)56   virtual void UpdateShuffleMode(dataset::ShuffleMode shuffle_mode) { shuffle_mode_ = shuffle_mode; }
57 
GetShuffleMode()58   virtual dataset::ShuffleMode GetShuffleMode() { return shuffle_mode_; }
59 
SetShardSampleCount(const std::vector<uint32_t> & shard_sample_count)60   virtual void SetShardSampleCount(const std::vector<uint32_t> &shard_sample_count) {
61     shard_sample_count_ = shard_sample_count;
62   }
63 
GetShardSampleCount()64   virtual std::vector<uint32_t> GetShardSampleCount() { return shard_sample_count_; }
65 
66  private:
67   std::shared_ptr<ShardOperator> child_op_ = nullptr;
68 
69   // indicate shard_id : inc_count
70   //   // 0 : 15  -  shard0 has 15 samples
71   //     // 1 : 41  -  shard1 has 26 samples
72   //       // 2 : 58  -  shard2 has 17 samples
73   std::vector<uint32_t> shard_sample_count_;
74 
75   dataset::ShuffleMode shuffle_mode_ = dataset::ShuffleMode::kGlobal;
76 };
77 }  // namespace mindrecord
78 }  // namespace mindspore
79 #endif  // MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_OPERATOR_H_
80