/**
 * Copyright 2021-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_DISTRIBUTED_PERSISTENT_STORAGE_STORAGE_H_
#define MINDSPORE_CCSRC_DISTRIBUTED_PERSISTENT_STORAGE_STORAGE_H_

// <cstddef> / <cstdint> are included explicitly because this header names size_t and int32_t directly; relying on
// other standard headers to pull them in transitively is not guaranteed.
#include <cstddef>
#include <cstdint>
#include <map>
#include <memory>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

namespace mindspore {
namespace distributed {
namespace storage {
// InputData consists of shape, const buffer pointer and size.
using InputData = std::tuple<std::vector<int>, const void *, size_t>;
// OutputData consists of buffer pointer and size.
using OutputData = std::pair<void *, size_t>;
// DirtyInfo indicates the part of a Tensor that needs to be rewritten to storage, e.g. the row numbers of an
// embedding table that must be written back.
using DirtyInfo = std::vector<int>;

// Data memory buffer and buffer length.
struct DataWithLen {
  void *data_{nullptr};
  size_t data_len_{0};
};

// Const data memory buffer and buffer length.
struct ConstDataWithLen {
  const void *data_{nullptr};
  size_t data_len_{0};
};

// This class provides upper-layer interfaces for persistent storage.
// Every virtual method has a default implementation that does nothing (GetAllKeys() returns nullptr), so a concrete
// storage only overrides the operations it supports.
// Template parameters:
//   KeyType   - element type of the keys returned by GetAllKeys().
//   ValueType - not referenced by this base class itself.
//               NOTE(review): presumably consumed by derived key-value storages; confirm.
template <typename KeyType = std::int32_t, typename ValueType = float>
class StorageBase {
 public:
  StorageBase() = default;
  virtual ~StorageBase() = default;

  // Initialize the storage module and allocate necessary resources. Default: no-op.
  virtual void Initialize() {}

  // Release the resources the storage module allocates. Default: no-op.
  virtual void Finalize() {}

  // Write the input tensor to the storage medium or memory buffer.
  // The parameter `dirty_info` indicates the part of the Tensor that needs to be rewritten to storage; for example,
  // when some rows of an embedding table need to be rewritten, `dirty_info` should contain those row numbers.
  // Default: no-op (parameter names are commented out to avoid -Wunused-parameter in every including file).
  virtual void Write(const InputData & /* input */, const DirtyInfo & /* dirty_info */) {}

  // Write input to the storage medium or memory buffer. At present, only input composed of multiple tensors with the
  // same shape and data type, all sharing the same dirty info, is supported.
  // The parameter `dirty_info` indicates the part of the Tensors that needs to be rewritten to storage.
  // Default: no-op.
  virtual void Write(const std::vector<InputData> & /* input */, const DirtyInfo & /* dirty_info */) {}

  // Write key-value pairs data into persistent storage.
  // Parameter[in] `keys`: The keys to write, containing data pointer and data buffer length.
  // Parameter[in] `values`: The values corresponding to `keys`, containing data pointer and data buffer length.
  // Default: no-op.
  virtual void Write(const ConstDataWithLen & /* keys */, const ConstDataWithLen & /* values */) {}

  // Read data from the storage medium or memory buffer and merge them into contiguous memory.
  // Default: no-op; `output` is left untouched.
  virtual void Read(const OutputData & /* output */) {}

  // Read data from the storage medium or memory buffer and merge them into contiguous memory for multiple tensors.
  // Default: no-op; `outputs` are left untouched.
  virtual void Read(const std::vector<OutputData> & /* outputs */) {}

  // Read key-value pairs' values data from persistent storage.
  // Parameter[in] `keys`: The keys whose values need to be read, containing data pointer and data buffer length.
  // Parameter[out] `values`: The buffer receiving the values corresponding to `keys`, containing data pointer and data
  // buffer length.
  // Default: no-op; the `values` buffer is left untouched.
  virtual void Read(const ConstDataWithLen & /* keys */, const DataWithLen & /* values */) {}

  // Dump all keys of all key-value pairs in storage.
  // Returns an owning pointer to the keys. The default implementation returns nullptr, so callers must check the
  // result before dereferencing it.
  virtual std::unique_ptr<std::vector<KeyType>> GetAllKeys() const { return nullptr; }
};
}  // namespace storage
}  // namespace distributed
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_DISTRIBUTED_PERSISTENT_STORAGE_STORAGE_H_