1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7
8 * http://www.apache.org/licenses/LICENSE-2.0
9
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "minddata/dataset/engine/cache/cache_fbb.h"
17 namespace mindspore {
18 namespace dataset {
19 /// A private function used by SerializeTensorRowHeader to serialize each column in a tensor
20 /// \note Not to be called by outside world
21 /// \return Status object
SerializeOneTensorMeta(const std::shared_ptr<flatbuffers::FlatBufferBuilder> & fbb,const std::shared_ptr<Tensor> & ts_ptr,flatbuffers::Offset<TensorMetaMsg> * out_off)22 Status SerializeOneTensorMeta(const std::shared_ptr<flatbuffers::FlatBufferBuilder> &fbb,
23 const std::shared_ptr<Tensor> &ts_ptr, flatbuffers::Offset<TensorMetaMsg> *out_off) {
24 RETURN_UNEXPECTED_IF_NULL(out_off);
25 const Tensor *ts = ts_ptr.get();
26 auto shape_off = fbb->CreateVector(ts->shape().AsVector());
27 const auto ptr = ts->GetBuffer();
28 if (ptr == nullptr) {
29 RETURN_STATUS_UNEXPECTED("Tensor buffer is null");
30 }
31 auto src = ts->type().value();
32 TensorType dest;
33 #define CASE(t) \
34 case DataType::t: \
35 dest = TensorType::TensorType_##t; \
36 break
37 // Map the type to fill in the flat buffer.
38 switch (src) {
39 CASE(DE_BOOL);
40 CASE(DE_INT8);
41 CASE(DE_UINT8);
42 CASE(DE_INT16);
43 CASE(DE_UINT16);
44 CASE(DE_INT32);
45 CASE(DE_UINT32);
46 CASE(DE_INT64);
47 CASE(DE_UINT64);
48 CASE(DE_FLOAT16);
49 CASE(DE_FLOAT32);
50 CASE(DE_FLOAT64);
51 CASE(DE_STRING);
52 CASE(DE_BYTES);
53 default:
54 MS_LOG(ERROR) << "Unknown tensor. Dumping content:\n" << *ts;
55 RETURN_STATUS_UNEXPECTED("Unknown type");
56 }
57 #undef CASE
58
59 TensorMetaMsgBuilder ts_builder(*fbb);
60 ts_builder.add_dims(shape_off);
61 ts_builder.add_type(dest);
62 auto ts_off = ts_builder.Finish();
63 *out_off = ts_off;
64 return Status::OK();
65 }
66
SerializeTensorRowHeader(const TensorRow & row,std::shared_ptr<flatbuffers::FlatBufferBuilder> * out_fbb)67 Status SerializeTensorRowHeader(const TensorRow &row, std::shared_ptr<flatbuffers::FlatBufferBuilder> *out_fbb) {
68 RETURN_UNEXPECTED_IF_NULL(out_fbb);
69 auto fbb = std::make_shared<flatbuffers::FlatBufferBuilder>();
70 try {
71 fbb = std::make_shared<flatbuffers::FlatBufferBuilder>();
72 std::vector<flatbuffers::Offset<TensorMetaMsg>> v;
73 std::vector<int64_t> tensor_sz;
74 v.reserve(row.size());
75 tensor_sz.reserve(row.size());
76 // We will go through each column in the row.
77 for (const std::shared_ptr<Tensor> &ts_ptr : row) {
78 flatbuffers::Offset<TensorMetaMsg> ts_off;
79 RETURN_IF_NOT_OK(SerializeOneTensorMeta(fbb, ts_ptr, &ts_off));
80 v.push_back(ts_off);
81 tensor_sz.push_back(ts_ptr->SizeInBytes());
82 }
83 auto column_off = fbb->CreateVector(v);
84 auto data_sz_off = fbb->CreateVector(tensor_sz);
85 TensorRowHeaderMsgBuilder row_builder(*fbb);
86 row_builder.add_column(column_off);
87 row_builder.add_data_sz(data_sz_off);
88 // Pass the row_id even if it may not be known.
89 row_builder.add_row_id(row.getId());
90 row_builder.add_size_of_this(-1); // fill in later after we call Finish.
91 auto out = row_builder.Finish();
92 fbb->Finish(out);
93 // Now go back to fill in size_of_this in the flat buffer.
94 auto msg = GetMutableTensorRowHeaderMsg(fbb->GetBufferPointer());
95 auto success = msg->mutate_size_of_this(fbb->GetSize());
96 if (!success) {
97 RETURN_STATUS_UNEXPECTED("Unable to set size_of_this");
98 }
99 (*out_fbb) = std::move(fbb);
100 return Status::OK();
101 } catch (const std::bad_alloc &e) {
102 RETURN_STATUS_OOM("Out of memory.");
103 }
104 }
105
RestoreOneTensor(const TensorMetaMsg * col_ts,const ReadableSlice & data,std::shared_ptr<Tensor> * out)106 Status RestoreOneTensor(const TensorMetaMsg *col_ts, const ReadableSlice &data, std::shared_ptr<Tensor> *out) {
107 RETURN_UNEXPECTED_IF_NULL(col_ts);
108 auto shape_in = col_ts->dims();
109 auto type_in = col_ts->type();
110 std::vector<dsize_t> v;
111 v.reserve(shape_in->size());
112 v.assign(shape_in->begin(), shape_in->end());
113 TensorShape shape(v);
114 DataType::Type dest = DataType::DE_UNKNOWN;
115 #define CASE(t) \
116 case TensorType_##t: \
117 dest = DataType::Type::t; \
118 break
119
120 switch (type_in) {
121 CASE(DE_BOOL);
122 CASE(DE_INT8);
123 CASE(DE_UINT8);
124 CASE(DE_INT16);
125 CASE(DE_UINT16);
126 CASE(DE_INT32);
127 CASE(DE_UINT32);
128 CASE(DE_INT64);
129 CASE(DE_UINT64);
130 CASE(DE_FLOAT16);
131 CASE(DE_FLOAT32);
132 CASE(DE_FLOAT64);
133 CASE(DE_STRING);
134 CASE(DE_BYTES);
135 }
136 #undef CASE
137
138 DataType type(dest);
139 std::shared_ptr<Tensor> ts;
140 RETURN_IF_NOT_OK(
141 Tensor::CreateFromMemory(shape, type, static_cast<const unsigned char *>(data.GetPointer()), data.GetSize(), &ts));
142 // Next we restore the real data which can be embedded or stored separately.
143 if (ts->SizeInBytes() != data.GetSize()) {
144 MS_LOG(ERROR) << "Unexpected length. Read " << data.GetSize() << ". Expected " << ts->SizeInBytes() << ".\n"
145 << "Dumping tensor\n"
146 << *ts << "\n";
147 RETURN_STATUS_UNEXPECTED("Length mismatch. See log file for details.");
148 }
149 *out = std::move(ts);
150 return Status::OK();
151 }
152 } // namespace dataset
153 } // namespace mindspore
154