/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "minddata/dataset/engine/cache/cache_fbb.h"
namespace mindspore {
namespace dataset {
/// A private function used by SerializeTensorRowHeader to serialize each column (tensor) in a tensor row
/// \note Not to be called by outside world
/// \return Status object
Status SerializeOneTensorMeta(const std::shared_ptr<flatbuffers::FlatBufferBuilder> &fbb,
                              const std::shared_ptr<Tensor> &ts_ptr, flatbuffers::Offset<TensorMetaMsg> *out_off) {
  RETURN_UNEXPECTED_IF_NULL(out_off);
  const Tensor *ts = ts_ptr.get();
  auto shape_off = fbb->CreateVector(ts->shape().AsVector());
  const auto ptr = ts->GetBuffer();
  if (ptr == nullptr) {
    RETURN_STATUS_UNEXPECTED("Tensor buffer is null");
  }
  auto src = ts->type().value();
  TensorType dest;
#define CASE(t)                        \
  case DataType::t:                    \
    dest = TensorType::TensorType_##t; \
    break
  // Map the type to fill in the flat buffer.
  switch (src) {
    CASE(DE_BOOL);
    CASE(DE_INT8);
    CASE(DE_UINT8);
    CASE(DE_INT16);
    CASE(DE_UINT16);
    CASE(DE_INT32);
    CASE(DE_UINT32);
    CASE(DE_INT64);
    CASE(DE_UINT64);
    CASE(DE_FLOAT16);
    CASE(DE_FLOAT32);
    CASE(DE_FLOAT64);
    CASE(DE_STRING);
    default:
      MS_LOG(ERROR) << "Unknown tensor. Dumping content:\n" << *ts;
      RETURN_STATUS_UNEXPECTED("Unknown type");
  }
#undef CASE

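  // Build the TensorMetaMsg table for this column: it records the shape and element type.
  // Finish() returns an offset into the flat buffer that the caller embeds in the row header.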
  TensorMetaMsgBuilder ts_builder(*fbb);
  ts_builder.add_dims(shape_off);
  ts_builder.add_type(dest);
  auto ts_off = ts_builder.Finish();
  *out_off = ts_off;
  return Status::OK();
}

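/// Serialize the header of a TensorRow into a flat buffer. The header carries, for every column, its shape and
/// type plus its data size in bytes, along with the row id and the total size of the header itself.
/// \return Status object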
Status SerializeTensorRowHeader(const TensorRow &row, std::shared_ptr<flatbuffers::FlatBufferBuilder> *out_fbb) {
  RETURN_UNEXPECTED_IF_NULL(out_fbb);
  std::shared_ptr<flatbuffers::FlatBufferBuilder> fbb;
  try {
    // Construct the builder inside the try block so an allocation failure is reported as an out-of-memory status.
    fbb = std::make_shared<flatbuffers::FlatBufferBuilder>();
    std::vector<flatbuffers::Offset<TensorMetaMsg>> v;
    std::vector<int64_t> tensor_sz;
    v.reserve(row.size());
    tensor_sz.reserve(row.size());
    // We will go through each column in the row.
    for (const std::shared_ptr<Tensor> &ts_ptr : row) {
      flatbuffers::Offset<TensorMetaMsg> ts_off;
      RETURN_IF_NOT_OK(SerializeOneTensorMeta(fbb, ts_ptr, &ts_off));
      v.push_back(ts_off);
      tensor_sz.push_back(ts_ptr->SizeInBytes());
    }
    auto column_off = fbb->CreateVector(v);
    auto data_sz_off = fbb->CreateVector(tensor_sz);
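    // Assemble the row header: per-column metadata offsets, per-column data sizes, the row id and a
    // placeholder for the total header size.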
    TensorRowHeaderMsgBuilder row_builder(*fbb);
    row_builder.add_column(column_off);
    row_builder.add_data_sz(data_sz_off);
    // Pass the row_id even if it may not be known.
    row_builder.add_row_id(row.getId());
    row_builder.add_size_of_this(-1);  // fill in later after we call Finish.
    auto out = row_builder.Finish();
    fbb->Finish(out);
    // Now go back to fill in size_of_this in the flat buffer.
    auto msg = GetMutableTensorRowHeaderMsg(fbb->GetBufferPointer());
    auto success = msg->mutate_size_of_this(fbb->GetSize());
    if (!success) {
      RETURN_STATUS_UNEXPECTED("Unable to set size_of_this");
    }
    (*out_fbb) = std::move(fbb);
    return Status::OK();
  } catch (const std::bad_alloc &e) {
    return Status(StatusCode::kMDOutOfMemory, __LINE__, __FILE__);
  }
}

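/// Restore one Tensor from its serialized TensorMetaMsg and the raw bytes referenced by a ReadableSlice.
/// \param col_ts Serialized metadata (shape and type) of the column
/// \param data Slice holding the tensor's raw data
/// \param[out] out The reconstructed tensor
/// \return Status object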
Status RestoreOneTensor(const TensorMetaMsg *col_ts, const ReadableSlice &data, std::shared_ptr<Tensor> *out) {
  RETURN_UNEXPECTED_IF_NULL(col_ts);
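  // Rebuild the TensorShape from the serialized dimension vector.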
  auto shape_in = col_ts->dims();
  auto type_in = col_ts->type();
  std::vector<dsize_t> v;
  v.reserve(shape_in->size());
  v.assign(shape_in->begin(), shape_in->end());
  TensorShape shape(v);
  DataType::Type dest = DataType::DE_UNKNOWN;
#define CASE(t)               \
  case TensorType_##t:        \
    dest = DataType::Type::t; \
    break

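  // Map the serialized TensorType back to the in-memory DataType. An unrecognized value leaves dest as DE_UNKNOWN.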
  switch (type_in) {
    CASE(DE_BOOL);
    CASE(DE_INT8);
    CASE(DE_UINT8);
    CASE(DE_INT16);
    CASE(DE_UINT16);
    CASE(DE_INT32);
    CASE(DE_UINT32);
    CASE(DE_INT64);
    CASE(DE_UINT64);
    CASE(DE_FLOAT16);
    CASE(DE_FLOAT32);
    CASE(DE_FLOAT64);
    CASE(DE_STRING);
  }
#undef CASE

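  // Create the tensor from the raw bytes in the slice using the shape and type decoded above.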
  DataType type(dest);
  std::shared_ptr<Tensor> ts;
  RETURN_IF_NOT_OK(
    Tensor::CreateFromMemory(shape, type, static_cast<const unsigned char *>(data.GetPointer()), data.GetSize(), &ts));
  // Next we restore the real data which can be embedded or stored separately.
  if (ts->SizeInBytes() != data.GetSize()) {
    MS_LOG(ERROR) << "Unexpected length. Read " << data.GetSize() << ". Expected " << ts->SizeInBytes() << ".\n"
                  << "Dumping tensor\n"
                  << *ts << "\n";
    RETURN_STATUS_UNEXPECTED("Length mismatch. See log file for details.");
  }
  *out = std::move(ts);
  return Status::OK();
}
}  // namespace dataset
}  // namespace mindspore