1 /** 2 * Copyright 2019-2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_OPERATOR_H_ 17 #define MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_OPERATOR_H_ 18 19 #include <memory> 20 #include <vector> 21 #include "minddata/mindrecord/include/shard_task_list.h" 22 #include "minddata/dataset/include/dataset/constants.h" 23 24 namespace mindspore { 25 namespace mindrecord { 26 class MINDRECORD_API ShardOperator { 27 public: 28 virtual ~ShardOperator() = default; 29 operator()30 Status operator()(ShardTaskList &tasks) { // NOLINT 31 RETURN_IF_NOT_OK_MR(this->PreExecute(tasks)); 32 RETURN_IF_NOT_OK_MR(this->Execute(tasks)); 33 RETURN_IF_NOT_OK_MR(this->SufExecute(tasks)); 34 return Status::OK(); 35 } 36 HasChildOp()37 virtual bool HasChildOp() { return child_op_ != nullptr; } 38 SetChildOp(const std::shared_ptr<ShardOperator> & child_op)39 virtual Status SetChildOp(const std::shared_ptr<ShardOperator> &child_op) { 40 if (child_op != nullptr) { 41 child_op_ = child_op; 42 } 43 return Status::OK(); 44 } 45 GetChildOp()46 virtual std::shared_ptr<ShardOperator> GetChildOp() { return child_op_; } 47 PreExecute(ShardTaskList & tasks)48 virtual Status PreExecute(ShardTaskList &tasks) { return Status::OK(); } // NOLINT 49 50 virtual Status Execute(ShardTaskList &tasks) = 0; // NOLINT 51 SufExecute(ShardTaskList & tasks)52 virtual Status SufExecute(ShardTaskList &tasks) { return Status::OK(); } // NOLINT 53 54 /// \brief compute actual the num_samples via loading data GetNumSamples(int64_t dataset_size,int64_t num_classes)55 virtual int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) { return 0; } 56 57 /// \brief Getter the number of samples which is set via python api GetNumSamples()58 virtual int64_t GetNumSamples() const { return num_samples_; } 59 60 /// \brief Setter the number of samples in python SetNumSamples(int64_t num_samples)61 virtual void SetNumSamples(int64_t num_samples) { num_samples_ = num_samples; } 62 UpdateShuffleMode(dataset::ShuffleMode shuffle_mode)63 virtual void UpdateShuffleMode(dataset::ShuffleMode shuffle_mode) { shuffle_mode_ = shuffle_mode; } 64 GetShuffleMode()65 virtual dataset::ShuffleMode GetShuffleMode() { return shuffle_mode_; } 66 SetShardSampleCount(const std::vector<int64_t> & shard_sample_count)67 virtual void SetShardSampleCount(const std::vector<int64_t> &shard_sample_count) { 68 shard_sample_count_ = shard_sample_count; 69 } 70 GetShardSampleCount()71 virtual std::vector<int64_t> GetShardSampleCount() { return shard_sample_count_; } 72 73 private: 74 int64_t num_samples_ = 0; 75 std::shared_ptr<ShardOperator> child_op_ = nullptr; 76 // indicate shard_id : inc_count 77 // 0 : 15 - shard0 has 15 samples 78 // 1 : 41 - shard1 has 26 samples 79 // 2 : 58 - shard2 has 17 samples 80 std::vector<int64_t> shard_sample_count_; 81 dataset::ShuffleMode shuffle_mode_ = dataset::ShuffleMode::kGlobal; 82 }; 83 } // namespace mindrecord 84 } // namespace mindspore 85 #endif // MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_OPERATOR_H_ 86