1 /** 2 * Copyright 2019-2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_OPERATOR_H_ 17 #define MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_OPERATOR_H_ 18 19 #include <memory> 20 #include <vector> 21 #include "minddata/mindrecord/include/shard_task_list.h" 22 #include "minddata/dataset/include/dataset/constants.h" 23 24 namespace mindspore { 25 namespace mindrecord { 26 class __attribute__((visibility("default"))) ShardOperator { 27 public: 28 virtual ~ShardOperator() = default; 29 operator()30 Status operator()(ShardTaskList &tasks) { 31 RETURN_IF_NOT_OK(this->PreExecute(tasks)); 32 RETURN_IF_NOT_OK(this->Execute(tasks)); 33 RETURN_IF_NOT_OK(this->SufExecute(tasks)); 34 return Status::OK(); 35 } 36 HasChildOp()37 virtual bool HasChildOp() { return child_op_ != nullptr; } 38 SetChildOp(std::shared_ptr<ShardOperator> child_op)39 virtual Status SetChildOp(std::shared_ptr<ShardOperator> child_op) { 40 if (child_op != nullptr) { 41 child_op_ = child_op; 42 } 43 return Status::OK(); 44 } 45 GetChildOp()46 virtual std::shared_ptr<ShardOperator> GetChildOp() { return child_op_; } 47 PreExecute(ShardTaskList & tasks)48 virtual Status PreExecute(ShardTaskList &tasks) { return Status::OK(); } 49 50 virtual Status Execute(ShardTaskList &tasks) = 0; 51 SufExecute(ShardTaskList & tasks)52 virtual Status SufExecute(ShardTaskList &tasks) { return Status::OK(); } 53 GetNumSamples(int64_t dataset_size,int64_t num_classes)54 virtual int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) { return 0; } 55 UpdateShuffleMode(dataset::ShuffleMode shuffle_mode)56 virtual void UpdateShuffleMode(dataset::ShuffleMode shuffle_mode) { shuffle_mode_ = shuffle_mode; } 57 GetShuffleMode()58 virtual dataset::ShuffleMode GetShuffleMode() { return shuffle_mode_; } 59 SetShardSampleCount(const std::vector<uint32_t> & shard_sample_count)60 virtual void SetShardSampleCount(const std::vector<uint32_t> &shard_sample_count) { 61 shard_sample_count_ = shard_sample_count; 62 } 63 GetShardSampleCount()64 virtual std::vector<uint32_t> GetShardSampleCount() { return shard_sample_count_; } 65 66 private: 67 std::shared_ptr<ShardOperator> child_op_ = nullptr; 68 69 // indicate shard_id : inc_count 70 // // 0 : 15 - shard0 has 15 samples 71 // // 1 : 41 - shard1 has 26 samples 72 // // 2 : 58 - shard2 has 17 samples 73 std::vector<uint32_t> shard_sample_count_; 74 75 dataset::ShuffleMode shuffle_mode_ = dataset::ShuffleMode::kGlobal; 76 }; 77 } // namespace mindrecord 78 } // namespace mindspore 79 #endif // MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_OPERATOR_H_ 80