# Copyright 2019-2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
This file contains standard format dataset loading classes.
You can convert a dataset to a standard format using the following steps:
    1. Use the mindspore.mindrecord.FileWriter / tf.io.TFRecordWriter api to
       convert the dataset to MindRecord / TFRecord.
    2. Use MindDataset / TFRecordDataset to load the MindRecord / TFRecord files.
After declaring the dataset object, you can further apply dataset operations
(e.g. filter, skip, concat, map, batch) on it.
"""
import platform

import numpy as np

import mindspore._c_dataengine as cde
from mindspore import log as logger

from .datasets import UnionBaseDataset, SourceDataset, MappableDataset, Shuffle, Schema, \
    shuffle_to_shuffle_mode, shuffle_to_bool
from .datasets_user_defined import GeneratorDataset
from .obs.obs_mindrecord_dataset import MindRecordFromOBS
from .validators import check_csvdataset, check_minddataset, check_tfrecorddataset, check_obsminddataset
from ...mindrecord.config import _get_enc_key, _get_dec_mode, _get_hash_mode, decrypt, verify_file_hash

from ..core.validator_helpers import replace_none
from . import samplers
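

# The module docstring above describes a two-step flow: convert raw data with
# mindspore.mindrecord.FileWriter, then load the result with MindDataset. The
# function below is a minimal, illustrative sketch of that flow and is never
# called by this module; the file name, schema and sample values are assumptions
# made for the example only.
def _example_write_and_read_mindrecord():
    """Illustrative sketch only: write a tiny MindRecord file and read it back."""
    import mindspore.dataset as ds
    from mindspore.mindrecord import FileWriter

    # 1. Convert raw samples to MindRecord with FileWriter.
    writer = FileWriter(file_name="example.mindrecord", shard_num=1)
    writer.add_schema({"data": {"type": "float32", "shape": [2]},
                       "label": {"type": "int32"}}, "example_schema")
    writer.write_raw_data([{"data": np.array([0.1, 0.2], dtype=np.float32), "label": 0},
                           {"data": np.array([0.3, 0.4], dtype=np.float32), "label": 1}])
    writer.commit()

    # 2. Load the MindRecord file and apply further dataset operations.
    dataset = ds.MindDataset(dataset_files="example.mindrecord")
    dataset = dataset.batch(2)
    return dataset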


class CSVDataset(SourceDataset, UnionBaseDataset):
    """
    A source dataset that reads and parses comma-separated values
    `(CSV) <https://en.wikipedia.org/wiki/Comma-separated_values>`_ files as dataset.

    The columns of generated dataset depend on the source CSV files.

    Args:
        dataset_files (Union[str, list[str]]): String or list of files to be read or glob strings to search
            for a pattern of files. The list will be sorted in lexicographical order.
        field_delim (str, optional): A string that indicates the char delimiter to separate fields.
            Default: ``','``.
        column_defaults (list, optional): List of default values for the CSV fields. Default: ``None``. Each item
            in the list must be of a valid type (float, int, or string). If this is not provided, all
            columns are treated as string type.
        column_names (list[str], optional): List of column names of the dataset. Default: ``None``. If this
            is not provided, the column names are inferred from the first row of the CSV file.
        num_samples (int, optional): The number of samples to be included in the dataset.
            Default: ``None``, will include all samples.
        num_parallel_workers (int, optional): Number of worker threads to read the data.
            Default: ``None``, will use global default workers(8), it can be set
            by :func:`mindspore.dataset.config.set_num_parallel_workers` .
        shuffle (Union[bool, Shuffle], optional): Perform reshuffling of the data every epoch.
            Default: ``Shuffle.GLOBAL`` . Bool type and Shuffle enum are both supported to pass in.
            If `shuffle` is ``False`` , no shuffling will be performed.
            If `shuffle` is ``True`` , performs global shuffle.
            There are two levels of shuffling, the desired shuffle enum is defined by
            :class:`mindspore.dataset.Shuffle` .

            - ``Shuffle.GLOBAL`` : Shuffle both the files and samples, same as setting `shuffle` to ``True``.

            - ``Shuffle.FILES`` : Shuffle files only.

        num_shards (int, optional): Number of shards that the dataset will be divided into. Default: ``None`` .
            When this argument is specified, `num_samples` reflects the maximum sample number per shard.
        shard_id (int, optional): The shard ID within `num_shards` . Default: ``None``. This
            argument can only be specified when `num_shards` is also specified.
        cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing. More details:
            `Single-Node Data Cache <https://www.mindspore.cn/tutorials/experts/en/master/dataset/cache.html>`_ .
            Default: ``None``, which means no cache is used.

    Raises:
        RuntimeError: If `dataset_files` are not valid or do not exist.
        ValueError: If `field_delim` is invalid.
        ValueError: If `num_parallel_workers` exceeds the max thread numbers.
        RuntimeError: If `num_shards` is specified but `shard_id` is None.
        RuntimeError: If `shard_id` is specified but `num_shards` is None.
        ValueError: If `shard_id` is not in range of [0, `num_shards` ).

    Examples:
        >>> import mindspore.dataset as ds
        >>> csv_dataset_dir = ["/path/to/csv_dataset_file"] # contains 1 or multiple csv files
        >>> dataset = ds.CSVDataset(dataset_files=csv_dataset_dir, column_names=['col1', 'col2', 'col3', 'col4'])
    """

    @check_csvdataset
    def __init__(self, dataset_files, field_delim=',', column_defaults=None, column_names=None, num_samples=None,
                 num_parallel_workers=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, cache=None):
        super().__init__(num_parallel_workers=num_parallel_workers, num_samples=num_samples, shuffle=shuffle,
                         num_shards=num_shards, shard_id=shard_id, cache=cache)
        self.dataset_files = self._find_files(dataset_files)
        self.dataset_files.sort()
        self.field_delim = replace_none(field_delim, ',')
        self.column_defaults = replace_none(column_defaults, [])
        self.column_names = replace_none(column_names, [])

    def parse(self, children=None):
        return cde.CSVNode(self.dataset_files, self.field_delim, self.column_defaults, self.column_names,
                           self.num_samples, self.shuffle_flag, self.num_shards, self.shard_id)
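

# The module docstring notes that further dataset operations can be chained on the
# returned object. The function below is a minimal, illustrative sketch of such a
# pipeline on a CSVDataset and is never called by this module; the file path and
# column names are assumptions made for the example only.
def _example_csv_pipeline():
    """Illustrative sketch only: chain skip, map and batch on a CSVDataset."""
    import mindspore.dataset as ds

    dataset = ds.CSVDataset(dataset_files=["/path/to/csv_dataset_file"],
                            column_names=['col1', 'col2', 'col3', 'col4'])
    dataset = dataset.skip(1)                          # drop the first row
    dataset = dataset.map(operations=lambda col: col,  # placeholder per-column transform
                          input_columns=['col1'])
    dataset = dataset.batch(32)
    return dataset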


class MindDataset(MappableDataset, UnionBaseDataset):
    """
    A source dataset that reads and parses MindRecord dataset.

    The columns of generated dataset depend on the source MindRecord files.

    Args:
        dataset_files (Union[str, list[str]]): If `dataset_files` is a str, it represents a file name of one
            component of a MindRecord source; other files with identical source in the same path will be found
            and loaded automatically. If `dataset_files` is a list, it represents a list of dataset files
            to be read directly.
        columns_list (list[str], optional): List of columns to be read. Default: ``None`` , read all columns.
        num_parallel_workers (int, optional): Number of worker threads to read the data.
            Default: ``None`` , will use global default workers(8), it can be set
            by :func:`mindspore.dataset.config.set_num_parallel_workers` .
        shuffle (Union[bool, Shuffle], optional): Perform reshuffling of the data every epoch.
            Default: ``None``, performs ``mindspore.dataset.Shuffle.GLOBAL``.
            Bool type and Shuffle enum are both supported to pass in.
            If `shuffle` is ``False`` , no shuffling will be performed.
            If `shuffle` is ``True`` , performs global shuffle.
            There are three levels of shuffling, the desired shuffle enum is defined by
            :class:`mindspore.dataset.Shuffle` .

            - ``Shuffle.GLOBAL`` : Global shuffle of all rows of data in dataset, same as setting `shuffle`
              to ``True``.

            - ``Shuffle.FILES`` : Shuffle the file sequence but keep the order of data within each file.
              Not supported when the number of samples in the dataset is greater than 100 million.

            - ``Shuffle.INFILE`` : Keep the file sequence the same but shuffle the data within each file.
              Not supported when the number of samples in the dataset is greater than 100 million.

        num_shards (int, optional): Number of shards that the dataset will be divided into. Default: ``None`` .
            When this argument is specified, `num_samples` reflects the maximum sample number per shard.
        shard_id (int, optional): The shard ID within `num_shards` . Default: ``None`` . This
            argument can only be specified when `num_shards` is also specified.
        sampler (Sampler, optional): Object used to choose samples from the
            dataset. Default: ``None`` . `sampler` is mutually exclusive
            with `shuffle` and block_reader. Supported samplers: :class:`mindspore.dataset.SubsetRandomSampler`,
            :class:`mindspore.dataset.PKSampler`, :class:`mindspore.dataset.RandomSampler`,
            :class:`mindspore.dataset.SequentialSampler`, :class:`mindspore.dataset.DistributedSampler`.
        padded_sample (dict, optional): Sample that will be appended to the dataset, whose
            keys are the same as `columns_list`. Default: ``None``.
        num_padded (int, optional): Number of padding samples. Dataset size
            plus `num_padded` should be divisible by `num_shards`. Default: ``None``.
        num_samples (int, optional): The number of samples to be included in the dataset.
            Default: ``None`` , all samples.
        cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing. More details:
            `Single-Node Data Cache <https://www.mindspore.cn/tutorials/experts/en/master/dataset/cache.html>`_ .
            Default: ``None`` , which means no cache is used.

    Raises:
        ValueError: If `dataset_files` are not valid or do not exist.
        ValueError: If `num_parallel_workers` exceeds the max thread numbers.
        RuntimeError: If `num_shards` is specified but `shard_id` is None.
        RuntimeError: If `shard_id` is specified but `num_shards` is None.
        ValueError: If `shard_id` is not in range of [0, `num_shards` ).

    Note:
        - The parameters `num_samples` , `shuffle` , `num_shards` , `shard_id` can be used to control the sampler
          used in the dataset, and their effects when combined with parameter `sampler` are as follows.

    .. include:: mindspore.dataset.sampler.txt

    Examples:
        >>> import mindspore.dataset as ds
        >>> mindrecord_files = ["/path/to/mind_dataset_file"] # contains 1 or multiple MindRecord files
        >>> dataset = ds.MindDataset(dataset_files=mindrecord_files)
    """

    def parse(self, children=None):
        return cde.MindDataNode(self.dataset_files, self.columns_list, self.sampler, self.new_padded_sample,
                                self.num_padded, shuffle_to_shuffle_mode(self.shuffle_option))

    @check_minddataset
    def __init__(self, dataset_files, columns_list=None, num_parallel_workers=None, shuffle=None, num_shards=None,
                 shard_id=None, sampler=None, padded_sample=None, num_padded=None, num_samples=None, cache=None):
        super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples,
                         shuffle=shuffle_to_bool(shuffle), num_shards=num_shards, shard_id=shard_id, cache=cache)
        if num_samples and shuffle in (Shuffle.FILES, Shuffle.INFILE):
            raise ValueError("'Shuffle.FILES' or 'Shuffle.INFILE' and 'num_samples' "
                             "cannot be specified at the same time.")
        self.shuffle_option = shuffle
        self.load_dataset = True
        if isinstance(dataset_files, list):
            self.load_dataset = False

        self.dataset_files = dataset_files
        if platform.system().lower() == "windows":
            # normalize Windows path separators
            if isinstance(dataset_files, list):
                self.dataset_files = [item.replace("\\", "/") for item in dataset_files]
            else:
                self.dataset_files = dataset_files.replace("\\", "/")

        # do decrypt & integrity check
        if not isinstance(self.dataset_files, list):
            if _get_enc_key() is not None or _get_hash_mode() is not None:
                logger.warning("When a single mindrecord file which is generated by " +
                               "`mindspore.mindrecord.FileWriter` with `shard_num` > 1 is used as the input, " +
                               "enabling decryption/integrity check may fail. Please use a file list as the input.")

            # decrypt the data file and index file
            index_file_name = self.dataset_files + ".db"
            self.dataset_files = decrypt(self.dataset_files, _get_enc_key(), _get_dec_mode())
            decrypt(index_file_name, _get_enc_key(), _get_dec_mode())

            # verify integrity check
            verify_file_hash(self.dataset_files)
            verify_file_hash(self.dataset_files + ".db")
        else:
            file_tuple = []
            for item in self.dataset_files:
                # decrypt the data file and index file
                index_file_name = item + ".db"
                decrypt_filename = decrypt(item, _get_enc_key(), _get_dec_mode())
                file_tuple.append(decrypt_filename)
                decrypt(index_file_name, _get_enc_key(), _get_dec_mode())

                # verify integrity check
                verify_file_hash(decrypt_filename)
                verify_file_hash(decrypt_filename + ".db")
            self.dataset_files = file_tuple

        self.columns_list = replace_none(columns_list, [])

        if sampler is not None:
            if not isinstance(sampler, (samplers.SubsetRandomSampler, samplers.SubsetSampler, samplers.PKSampler,
                                        samplers.DistributedSampler, samplers.RandomSampler,
                                        samplers.SequentialSampler)):
                raise ValueError("The sampler is not supported yet.")

        self.padded_sample = padded_sample
        self.num_padded = replace_none(num_padded, 0)

        self.new_padded_sample = {}
        if padded_sample:
            for k, v in padded_sample.items():
                if isinstance(v, np.ndarray):
                    self.new_padded_sample[k] = v.tobytes()
                else:
                    self.new_padded_sample[k] = v
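

# `padded_sample` and `num_padded` above let the dataset be padded so that its size is
# divisible by `num_shards`. The function below is a minimal, illustrative sketch of
# that usage and is never called by this module; the file path, column names, shapes
# and counts are assumptions made for the example only.
def _example_mindrecord_padding():
    """Illustrative sketch only: pad a MindDataset so it divides evenly across shards."""
    import mindspore.dataset as ds

    # Suppose the MindRecord files hold 10 samples and 4 shards are wanted:
    # appending 2 padded samples gives 12 rows, i.e. 3 rows per shard.
    padded_sample = {"data": np.zeros(2, dtype=np.float32), "label": -1}
    dataset = ds.MindDataset(dataset_files=["/path/to/mind_dataset_file"],
                             columns_list=["data", "label"],
                             padded_sample=padded_sample, num_padded=2,
                             num_shards=4, shard_id=0)
    return dataset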


class TFRecordDataset(SourceDataset, UnionBaseDataset):
    """
    A source dataset that reads and parses datasets stored on disk in TFData format.

    The columns of generated dataset depend on the source TFRecord files.

    Note:
        `TFRecordDataset` is not supported on the Windows platform yet.

    Args:
        dataset_files (Union[str, list[str]]): String or list of files to be read or glob strings to search for a
            pattern of files. The list will be sorted in lexicographical order.
        schema (Union[str, Schema], optional): Data format policy, which specifies the data types and shapes of
            the data columns to be read. Both JSON file path and objects constructed by
            :class:`mindspore.dataset.Schema` are acceptable. Default: ``None`` .
        columns_list (list[str], optional): List of columns to be read. Default: ``None`` , read all columns.
        num_samples (int, optional): The number of samples (rows) to be included in the dataset. Default: ``None`` .
            When `num_shards` and `shard_id` are specified, it will be interpreted as the number of rows per shard.
            The processing priority for `num_samples` is as follows:

            - If `num_samples` is specified with a value > 0, read `num_samples` samples.

            - If `num_samples` is not specified and numRows (parsed from `schema`) is > 0, read numRows samples.

            - If neither `num_samples` nor `schema` is specified, read the full dataset.

        num_parallel_workers (int, optional): Number of worker threads to read the data.
            Default: ``None`` , will use global default workers(8), it can be set
            by :func:`mindspore.dataset.config.set_num_parallel_workers` .
        shuffle (Union[bool, Shuffle], optional): Perform reshuffling of the data every epoch.
            Default: ``Shuffle.GLOBAL`` . Bool type and Shuffle enum are both supported to pass in.
            If `shuffle` is ``False``, no shuffling will be performed.
            If `shuffle` is ``True``, performs global shuffle.
            There are two levels of shuffling, the desired shuffle enum is defined by
            :class:`mindspore.dataset.Shuffle` .

            - ``Shuffle.GLOBAL`` : Shuffle both the files and samples, same as setting `shuffle` to ``True``.

            - ``Shuffle.FILES`` : Shuffle files only.

        num_shards (int, optional): Number of shards that the dataset will be divided
            into. Default: ``None`` . When this argument is specified, `num_samples` reflects
            the maximum sample number per shard.
        shard_id (int, optional): The shard ID within `num_shards` . Default: ``None`` . This
            argument can only be specified when `num_shards` is also specified.
        shard_equal_rows (bool, optional): Get equal rows for all shards. Default: ``False``. If `shard_equal_rows`
            is False, the number of rows of each shard may not be equal, which may lead to a failure in
            distributed training. When the number of samples per TFRecord file is not equal, it is suggested to
            set it to ``True``. This argument should only be specified when `num_shards` is also specified.
            When `compression_type` is not ``None``, and `num_samples` or numRows (parsed from `schema` ) is
            provided, `shard_equal_rows` will be implied as ``True``.
        cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing. More details:
            `Single-Node Data Cache <https://www.mindspore.cn/tutorials/experts/en/master/dataset/cache.html>`_ .
            Default: ``None`` , which means no cache is used.
        compression_type (str, optional): The type of compression used for all files, must be either ``''``,
            ``'GZIP'``, or ``'ZLIB'``. Default: ``None`` , which is equivalent to the empty string. It is highly
            recommended to provide `num_samples` or numRows (parsed from `schema`) when `compression_type` is
            ``"GZIP"`` or ``"ZLIB"`` to avoid performance degradation caused by multiple decompressions of the
            same file to obtain the file size.

    Raises:
        ValueError: If `dataset_files` are not valid or do not exist.
        ValueError: If `num_parallel_workers` exceeds the max thread numbers.
        RuntimeError: If `num_shards` is specified but `shard_id` is None.
        RuntimeError: If `shard_id` is specified but `num_shards` is None.
        ValueError: If `shard_id` is not in range of [0, `num_shards` ).
        ValueError: If `compression_type` is not ``''``, ``'GZIP'`` or ``'ZLIB'`` .
        ValueError: If `compression_type` is provided, but the number of dataset files < `num_shards` .
        ValueError: If `num_samples` < 0.

    Examples:
        >>> import mindspore.dataset as ds
        >>> from mindspore import dtype as mstype
        >>>
        >>> tfrecord_dataset_dir = ["/path/to/tfrecord_dataset_file"] # contains 1 or multiple TFRecord files
        >>> tfrecord_schema_file = "/path/to/tfrecord_schema_file"
        >>>
        >>> # 1) Get all rows from tfrecord_dataset_dir with no explicit schema.
        >>> # The meta-data in the first row will be used as a schema.
        >>> dataset = ds.TFRecordDataset(dataset_files=tfrecord_dataset_dir)
        >>>
        >>> # 2) Get all rows from tfrecord_dataset_dir with user-defined schema.
        >>> schema = ds.Schema()
        >>> schema.add_column(name='col_1d', de_type=mstype.int64, shape=[2])
        >>> dataset = ds.TFRecordDataset(dataset_files=tfrecord_dataset_dir, schema=schema)
        >>>
        >>> # 3) Get all rows from tfrecord_dataset_dir with the schema file.
        >>> dataset = ds.TFRecordDataset(dataset_files=tfrecord_dataset_dir, schema=tfrecord_schema_file)
    """

    @check_tfrecorddataset
    def __init__(self, dataset_files, schema=None, columns_list=None, num_samples=None, num_parallel_workers=None,
                 shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, shard_equal_rows=False,
                 cache=None, compression_type=None):
        super().__init__(num_parallel_workers=num_parallel_workers, num_samples=num_samples, shuffle=shuffle,
                         num_shards=num_shards, shard_id=shard_id, cache=cache)
        if platform.system().lower() == "windows":
            raise NotImplementedError("TFRecordDataset is not supported on Windows.")
        self.dataset_files = self._find_files(dataset_files)
        self.dataset_files.sort()

        self.schema = schema
        self.columns_list = replace_none(columns_list, [])
        self.shard_equal_rows = replace_none(shard_equal_rows, False)
        self.compression_type = replace_none(compression_type, "")

        # Only take numRows from schema when num_samples is not provided
        if self.schema is not None and (self.num_samples is None or self.num_samples == 0):
            self.num_samples = Schema.get_num_rows(self.schema)

        if self.compression_type in ['ZLIB', 'GZIP'] and (self.num_samples is None or self.num_samples == 0):
            logger.warning("Since compression_type is set, but neither num_samples nor numRows (from schema file) " +
                           "is provided, performance might be degraded.")

    def parse(self, children=None):
        schema = self.schema.cpp_schema if isinstance(self.schema, Schema) else self.schema
        return cde.TFRecordNode(self.dataset_files, schema, self.columns_list, self.num_samples, self.shuffle_flag,
                                self.num_shards, self.shard_id, self.shard_equal_rows, self.compression_type)
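

# The `compression_type` docstring above recommends providing `num_samples` (or numRows
# via `schema`) for compressed files, so the reader does not repeatedly decompress files
# just to determine their size. The function below is a minimal, illustrative sketch and
# is never called by this module; the file path and row count are assumptions made for
# the example only.
def _example_compressed_tfrecord():
    """Illustrative sketch only: read GZIP-compressed TFRecord files with a known row count."""
    import mindspore.dataset as ds

    dataset = ds.TFRecordDataset(dataset_files=["/path/to/tfrecord_dataset_file"],
                                 compression_type="GZIP",
                                 num_samples=10000)  # assumed total number of rows
    return dataset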


class OBSMindDataset(GeneratorDataset):
    """
    A source dataset that reads and parses MindRecord dataset which is stored in cloud storage
    such as OBS, Minio or AWS S3.

    The columns of generated dataset depend on the source MindRecord files.

    Args:
        dataset_files (list[str]): List of files in cloud storage to be read, where each file path is in
            the format of s3://bucketName/objectKey.
        server (str): Endpoint for accessing cloud storage.
            If it's OBS Service of Huawei Cloud, the endpoint is
            like ``<obs.cn-north-4.myhuaweicloud.com>`` (Region cn-north-4).
            If it's Minio started locally, the endpoint is like ``<https://127.0.0.1:9000>``.
        ak (str): The access key ID used to access the OBS data.
        sk (str): The secret access key used to access the OBS data.
        sync_obs_path (str): Remote dir path used for synchronization, users need to
            create it on cloud storage in advance. The path is in the format of s3://bucketName/objectKey.
        columns_list (list[str], optional): List of columns to be read. Default: ``None`` , read all columns.
        shuffle (Union[bool, Shuffle], optional): Perform reshuffling of the data every epoch.
            Default: ``Shuffle.GLOBAL``. Bool type and Shuffle enum are both supported to pass in.
            If `shuffle` is ``False`` , no shuffling will be performed.
            If `shuffle` is ``True`` , performs global shuffle.
            There are three levels of shuffling, the desired shuffle enum is defined by
            :class:`mindspore.dataset.Shuffle` .

            - ``Shuffle.GLOBAL`` : Global shuffle of all rows of data in dataset, same as setting `shuffle`
              to ``True``.

            - ``Shuffle.FILES`` : Shuffle the file sequence but keep the order of data within each file.

            - ``Shuffle.INFILE`` : Keep the file sequence the same but shuffle the data within each file.

        num_shards (int, optional): Number of shards that the dataset will be divided
            into. Default: ``None`` .
        shard_id (int, optional): The shard ID within `num_shards`. Default: ``None`` . This
            argument can only be specified when `num_shards` is also specified.
        shard_equal_rows (bool, optional): Get equal rows for all shards. Default: ``True``. If `shard_equal_rows`
            is False, the number of rows of each shard may not be equal, which may lead to a failure in
            distributed training. When the number of samples per MindRecord file is not equal, it is suggested
            to set it to ``True``. This argument should only be specified when `num_shards` is also specified.

    Raises:
        RuntimeError: If `sync_obs_path` does not exist.
        ValueError: If `columns_list` is invalid.
        RuntimeError: If `num_shards` is specified but `shard_id` is None.
        RuntimeError: If `shard_id` is specified but `num_shards` is None.
        ValueError: If `shard_id` is not in range of [0, `num_shards` ).

    Note:
        - It's necessary to create a synchronization directory on cloud storage in
          advance, which is specified by the parameter `sync_obs_path` .
        - If training is offline (no cloud), it's recommended to set the
          environment variable `BATCH_JOB_ID` .
        - In distributed training, if there are multiple nodes (servers), all 8
          devices must be used in each node (server). If there is only one
          node (server), there is no such restriction.

    Examples:
        >>> import mindspore.dataset as ds
        >>> # OBS
        >>> bucket = "iris"  # your obs bucket name
        >>> # the bucket directory structure is similar to the following:
        >>> #  - imagenet21k
        >>> #     | - mr_imagenet21k_01
        >>> #     | - mr_imagenet21k_02
        >>> #  - sync_node
        >>> dataset_obs_dir = ["s3://" + bucket + "/imagenet21k/mr_imagenet21k_01",
        ...                    "s3://" + bucket + "/imagenet21k/mr_imagenet21k_02"]
        >>> sync_obs_dir = "s3://" + bucket + "/sync_node"
        >>> num_shards = 8
        >>> shard_id = 0
        >>> dataset = ds.OBSMindDataset(dataset_obs_dir, "obs.cn-north-4.myhuaweicloud.com",
        ...                             "AK of OBS", "SK of OBS",
        ...                             sync_obs_dir, shuffle=True, num_shards=num_shards, shard_id=shard_id)
    """

    @check_obsminddataset
    def __init__(self, dataset_files, server, ak, sk, sync_obs_path,
                 columns_list=None,
                 shuffle=Shuffle.GLOBAL,
                 num_shards=None,
                 shard_id=None,
                 shard_equal_rows=True):

        from .obs.config_loader import config
        config.AK = ak
        config.SK = sk
        config.SERVER = server
        config.SYNC_OBS_PATH = sync_obs_path

        if shuffle is not None and not isinstance(shuffle, (bool, Shuffle)):
            raise TypeError("shuffle must be a boolean or an enum of 'Shuffle' values like 'Shuffle.GLOBAL' or "
                            "'Shuffle.FILES'.")

        self.num_shards = replace_none(num_shards, 1)
        self.shard_id = replace_none(shard_id, 0)
        self.shuffle = replace_none(shuffle, True)

        dataset = MindRecordFromOBS(dataset_files, columns_list, shuffle, self.num_shards, self.shard_id,
                                    shard_equal_rows, config.DATASET_LOCAL_PATH)
        if not columns_list:
            columns_list = dataset.get_col_names()
        else:
            full_columns_list = dataset.get_col_names()
            if not set(columns_list).issubset(full_columns_list):
                raise ValueError("columns_list: {} cannot be found in MindRecord fields: {}".format(
                    columns_list, full_columns_list))
        super().__init__(source=dataset, column_names=columns_list, num_shards=None, shard_id=None, shuffle=False)

    def add_sampler(self, new_sampler):
        raise NotImplementedError("add_sampler is not supported for OBSMindDataset.")

    def use_sampler(self, new_sampler):
        raise NotImplementedError("use_sampler is not supported for OBSMindDataset.")
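

# The Note in OBSMindDataset's docstring recommends setting the BATCH_JOB_ID environment
# variable when training offline (outside the cloud environment). The function below is a
# minimal, illustrative sketch and is never called by this module; the job id value is an
# assumption made for the example only.
def _example_obs_offline_setup():
    """Illustrative sketch only: set BATCH_JOB_ID before building an OBSMindDataset offline."""
    import os
    os.environ.setdefault("BATCH_JOB_ID", "job_0")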