• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3 
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7 
8  * http://www.apache.org/licenses/LICENSE-2.0
9 
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15 */
16 #include "minddata/dataset/engine/cache/cache_fbb.h"
17 namespace mindspore {
18 namespace dataset {
19 /// A private function used by SerializeTensorRowHeader to serialize each column in a tensor
20 /// \note Not to be called by outside world
21 /// \return Status object
SerializeOneTensorMeta(const std::shared_ptr<flatbuffers::FlatBufferBuilder> & fbb,const std::shared_ptr<Tensor> & ts_ptr,flatbuffers::Offset<TensorMetaMsg> * out_off)22 Status SerializeOneTensorMeta(const std::shared_ptr<flatbuffers::FlatBufferBuilder> &fbb,
23                               const std::shared_ptr<Tensor> &ts_ptr, flatbuffers::Offset<TensorMetaMsg> *out_off) {
24   RETURN_UNEXPECTED_IF_NULL(out_off);
25   const Tensor *ts = ts_ptr.get();
26   auto shape_off = fbb->CreateVector(ts->shape().AsVector());
27   const auto ptr = ts->GetBuffer();
28   if (ptr == nullptr) {
29     RETURN_STATUS_UNEXPECTED("Tensor buffer is null");
30   }
31   auto src = ts->type().value();
32   TensorType dest;
33 #define CASE(t)                        \
34   case DataType::t:                    \
35     dest = TensorType::TensorType_##t; \
36     break
37   // Map the type to fill in the flat buffer.
38   switch (src) {
39     CASE(DE_BOOL);
40     CASE(DE_INT8);
41     CASE(DE_UINT8);
42     CASE(DE_INT16);
43     CASE(DE_UINT16);
44     CASE(DE_INT32);
45     CASE(DE_UINT32);
46     CASE(DE_INT64);
47     CASE(DE_UINT64);
48     CASE(DE_FLOAT16);
49     CASE(DE_FLOAT32);
50     CASE(DE_FLOAT64);
51     CASE(DE_STRING);
52     default:
53       MS_LOG(ERROR) << "Unknown tensor. Dumping content:\n" << *ts;
54       RETURN_STATUS_UNEXPECTED("Unknown type");
55   }
56 #undef CASE
57 
58   TensorMetaMsgBuilder ts_builder(*fbb);
59   ts_builder.add_dims(shape_off);
60   ts_builder.add_type(dest);
61   auto ts_off = ts_builder.Finish();
62   *out_off = ts_off;
63   return Status::OK();
64 }
65 
SerializeTensorRowHeader(const TensorRow & row,std::shared_ptr<flatbuffers::FlatBufferBuilder> * out_fbb)66 Status SerializeTensorRowHeader(const TensorRow &row, std::shared_ptr<flatbuffers::FlatBufferBuilder> *out_fbb) {
67   RETURN_UNEXPECTED_IF_NULL(out_fbb);
68   auto fbb = std::make_shared<flatbuffers::FlatBufferBuilder>();
69   try {
70     fbb = std::make_shared<flatbuffers::FlatBufferBuilder>();
71     std::vector<flatbuffers::Offset<TensorMetaMsg>> v;
72     std::vector<int64_t> tensor_sz;
73     v.reserve(row.size());
74     tensor_sz.reserve(row.size());
75     // We will go through each column in the row.
76     for (const std::shared_ptr<Tensor> &ts_ptr : row) {
77       flatbuffers::Offset<TensorMetaMsg> ts_off;
78       RETURN_IF_NOT_OK(SerializeOneTensorMeta(fbb, ts_ptr, &ts_off));
79       v.push_back(ts_off);
80       tensor_sz.push_back(ts_ptr->SizeInBytes());
81     }
82     auto column_off = fbb->CreateVector(v);
83     auto data_sz_off = fbb->CreateVector(tensor_sz);
84     TensorRowHeaderMsgBuilder row_builder(*fbb);
85     row_builder.add_column(column_off);
86     row_builder.add_data_sz(data_sz_off);
87     // Pass the row_id even if it may not be known.
88     row_builder.add_row_id(row.getId());
89     row_builder.add_size_of_this(-1);  // fill in later after we call Finish.
90     auto out = row_builder.Finish();
91     fbb->Finish(out);
92     // Now go back to fill in size_of_this in the flat buffer.
93     auto msg = GetMutableTensorRowHeaderMsg(fbb->GetBufferPointer());
94     auto success = msg->mutate_size_of_this(fbb->GetSize());
95     if (!success) {
96       RETURN_STATUS_UNEXPECTED("Unable to set size_of_this");
97     }
98     (*out_fbb) = std::move(fbb);
99     return Status::OK();
100   } catch (const std::bad_alloc &e) {
101     return Status(StatusCode::kMDOutOfMemory, __LINE__, __FILE__);
102   }
103 }
104 
RestoreOneTensor(const TensorMetaMsg * col_ts,const ReadableSlice & data,std::shared_ptr<Tensor> * out)105 Status RestoreOneTensor(const TensorMetaMsg *col_ts, const ReadableSlice &data, std::shared_ptr<Tensor> *out) {
106   RETURN_UNEXPECTED_IF_NULL(col_ts);
107   auto shape_in = col_ts->dims();
108   auto type_in = col_ts->type();
109   std::vector<dsize_t> v;
110   v.reserve(shape_in->size());
111   v.assign(shape_in->begin(), shape_in->end());
112   TensorShape shape(v);
113   DataType::Type dest = DataType::DE_UNKNOWN;
114 #define CASE(t)               \
115   case TensorType_##t:        \
116     dest = DataType::Type::t; \
117     break
118 
119   switch (type_in) {
120     CASE(DE_BOOL);
121     CASE(DE_INT8);
122     CASE(DE_UINT8);
123     CASE(DE_INT16);
124     CASE(DE_UINT16);
125     CASE(DE_INT32);
126     CASE(DE_UINT32);
127     CASE(DE_INT64);
128     CASE(DE_UINT64);
129     CASE(DE_FLOAT16);
130     CASE(DE_FLOAT32);
131     CASE(DE_FLOAT64);
132     CASE(DE_STRING);
133   }
134 #undef CASE
135 
136   DataType type(dest);
137   std::shared_ptr<Tensor> ts;
138   RETURN_IF_NOT_OK(
139     Tensor::CreateFromMemory(shape, type, static_cast<const unsigned char *>(data.GetPointer()), data.GetSize(), &ts));
140   // Next we restore the real data which can be embedded or stored separately.
141   if (ts->SizeInBytes() != data.GetSize()) {
142     MS_LOG(ERROR) << "Unexpected length. Read " << data.GetSize() << ". Expected " << ts->SizeInBytes() << ".\n"
143                   << "Dumping tensor\n"
144                   << *ts << "\n";
145     RETURN_STATUS_UNEXPECTED("Length mismatch. See log file for details.");
146   }
147   *out = std::move(ts);
148   return Status::OK();
149 }
150 }  // namespace dataset
151 }  // namespace mindspore
152