1 /** 2 * Copyright 2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #include "minddata/dataset/engine/ir/datasetops/source/samplers/mindrecord_sampler_ir.h" 18 19 #include <memory> 20 #include <utility> 21 22 #ifndef ENABLE_ANDROID 23 #include "minddata/dataset/engine/datasetops/source/sampler/mind_record_sampler.h" 24 #include "minddata/mindrecord/include/shard_reader.h" 25 #endif 26 27 namespace mindspore { 28 namespace dataset { 29 #ifndef ENABLE_ANDROID 30 // This function not only creates a runtime sampler object, but also creates a ShardReader, 31 // which will also be needed to build a runtime MindRecordOp 32 // (cannot add another output parameter because it has to override base class's function) SamplerBuild(std::shared_ptr<SamplerRT> * sampler)33Status MindRecordSamplerObj::SamplerBuild(std::shared_ptr<SamplerRT> *sampler) { 34 shard_reader_ = std::make_unique<mindrecord::ShardReader>(); 35 *sampler = std::make_shared<MindRecordSamplerRT>(shard_reader_.get()); 36 return Status::OK(); 37 } 38 SamplerCopy()39std::shared_ptr<SamplerObj> MindRecordSamplerObj::SamplerCopy() { 40 auto sampler = std::make_shared<MindRecordSamplerObj>(); 41 return sampler; 42 } 43 44 // Function to acquire the unique pointer of the newly created ShardReader object 45 // Note this function can only be called after SamplerBuild is finished, and can only be called once. Otherwise this 46 // function will return error status. GetShardReader(std::unique_ptr<mindrecord::ShardReader> * shard_reader)47Status MindRecordSamplerObj::GetShardReader(std::unique_ptr<mindrecord::ShardReader> *shard_reader) { 48 CHECK_FAIL_RETURN_UNEXPECTED(shard_reader_ != nullptr, "[Internal ERROR] Attempt to get an empty shard reader."); 49 *shard_reader = std::move(shard_reader_); 50 return Status::OK(); 51 } 52 #endif 53 } // namespace dataset 54 } // namespace mindspore 55