• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3 
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7 
8  * http://www.apache.org/licenses/LICENSE-2.0
9 
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15 */
16 #include "minddata/dataset/engine/cache/cache_fbb.h"
17 namespace mindspore {
18 namespace dataset {
19 /// A private function used by SerializeTensorRowHeader to serialize each column in a tensor
20 /// \note Not to be called by outside world
21 /// \return Status object
SerializeOneTensorMeta(const std::shared_ptr<flatbuffers::FlatBufferBuilder> & fbb,const std::shared_ptr<Tensor> & ts_ptr,flatbuffers::Offset<TensorMetaMsg> * out_off)22 Status SerializeOneTensorMeta(const std::shared_ptr<flatbuffers::FlatBufferBuilder> &fbb,
23                               const std::shared_ptr<Tensor> &ts_ptr, flatbuffers::Offset<TensorMetaMsg> *out_off) {
24   RETURN_UNEXPECTED_IF_NULL(out_off);
25   const Tensor *ts = ts_ptr.get();
26   auto shape_off = fbb->CreateVector(ts->shape().AsVector());
27   const auto ptr = ts->GetBuffer();
28   if (ptr == nullptr) {
29     RETURN_STATUS_UNEXPECTED("Tensor buffer is null");
30   }
31   auto src = ts->type().value();
32   TensorType dest;
33 #define CASE(t)                        \
34   case DataType::t:                    \
35     dest = TensorType::TensorType_##t; \
36     break
37   // Map the type to fill in the flat buffer.
38   switch (src) {
39     CASE(DE_BOOL);
40     CASE(DE_INT8);
41     CASE(DE_UINT8);
42     CASE(DE_INT16);
43     CASE(DE_UINT16);
44     CASE(DE_INT32);
45     CASE(DE_UINT32);
46     CASE(DE_INT64);
47     CASE(DE_UINT64);
48     CASE(DE_FLOAT16);
49     CASE(DE_FLOAT32);
50     CASE(DE_FLOAT64);
51     CASE(DE_STRING);
52     CASE(DE_BYTES);
53     default:
54       MS_LOG(ERROR) << "Unknown tensor. Dumping content:\n" << *ts;
55       RETURN_STATUS_UNEXPECTED("Unknown type");
56   }
57 #undef CASE
58 
59   TensorMetaMsgBuilder ts_builder(*fbb);
60   ts_builder.add_dims(shape_off);
61   ts_builder.add_type(dest);
62   auto ts_off = ts_builder.Finish();
63   *out_off = ts_off;
64   return Status::OK();
65 }
66 
SerializeTensorRowHeader(const TensorRow & row,std::shared_ptr<flatbuffers::FlatBufferBuilder> * out_fbb)67 Status SerializeTensorRowHeader(const TensorRow &row, std::shared_ptr<flatbuffers::FlatBufferBuilder> *out_fbb) {
68   RETURN_UNEXPECTED_IF_NULL(out_fbb);
69   auto fbb = std::make_shared<flatbuffers::FlatBufferBuilder>();
70   try {
71     fbb = std::make_shared<flatbuffers::FlatBufferBuilder>();
72     std::vector<flatbuffers::Offset<TensorMetaMsg>> v;
73     std::vector<int64_t> tensor_sz;
74     v.reserve(row.size());
75     tensor_sz.reserve(row.size());
76     // We will go through each column in the row.
77     for (const std::shared_ptr<Tensor> &ts_ptr : row) {
78       flatbuffers::Offset<TensorMetaMsg> ts_off;
79       RETURN_IF_NOT_OK(SerializeOneTensorMeta(fbb, ts_ptr, &ts_off));
80       v.push_back(ts_off);
81       tensor_sz.push_back(ts_ptr->SizeInBytes());
82     }
83     auto column_off = fbb->CreateVector(v);
84     auto data_sz_off = fbb->CreateVector(tensor_sz);
85     TensorRowHeaderMsgBuilder row_builder(*fbb);
86     row_builder.add_column(column_off);
87     row_builder.add_data_sz(data_sz_off);
88     // Pass the row_id even if it may not be known.
89     row_builder.add_row_id(row.getId());
90     row_builder.add_size_of_this(-1);  // fill in later after we call Finish.
91     auto out = row_builder.Finish();
92     fbb->Finish(out);
93     // Now go back to fill in size_of_this in the flat buffer.
94     auto msg = GetMutableTensorRowHeaderMsg(fbb->GetBufferPointer());
95     auto success = msg->mutate_size_of_this(fbb->GetSize());
96     if (!success) {
97       RETURN_STATUS_UNEXPECTED("Unable to set size_of_this");
98     }
99     (*out_fbb) = std::move(fbb);
100     return Status::OK();
101   } catch (const std::bad_alloc &e) {
102     RETURN_STATUS_OOM("Out of memory.");
103   }
104 }
105 
RestoreOneTensor(const TensorMetaMsg * col_ts,const ReadableSlice & data,std::shared_ptr<Tensor> * out)106 Status RestoreOneTensor(const TensorMetaMsg *col_ts, const ReadableSlice &data, std::shared_ptr<Tensor> *out) {
107   RETURN_UNEXPECTED_IF_NULL(col_ts);
108   auto shape_in = col_ts->dims();
109   auto type_in = col_ts->type();
110   std::vector<dsize_t> v;
111   v.reserve(shape_in->size());
112   v.assign(shape_in->begin(), shape_in->end());
113   TensorShape shape(v);
114   DataType::Type dest = DataType::DE_UNKNOWN;
115 #define CASE(t)               \
116   case TensorType_##t:        \
117     dest = DataType::Type::t; \
118     break
119 
120   switch (type_in) {
121     CASE(DE_BOOL);
122     CASE(DE_INT8);
123     CASE(DE_UINT8);
124     CASE(DE_INT16);
125     CASE(DE_UINT16);
126     CASE(DE_INT32);
127     CASE(DE_UINT32);
128     CASE(DE_INT64);
129     CASE(DE_UINT64);
130     CASE(DE_FLOAT16);
131     CASE(DE_FLOAT32);
132     CASE(DE_FLOAT64);
133     CASE(DE_STRING);
134     CASE(DE_BYTES);
135   }
136 #undef CASE
137 
138   DataType type(dest);
139   std::shared_ptr<Tensor> ts;
140   RETURN_IF_NOT_OK(
141     Tensor::CreateFromMemory(shape, type, static_cast<const unsigned char *>(data.GetPointer()), data.GetSize(), &ts));
142   // Next we restore the real data which can be embedded or stored separately.
143   if (ts->SizeInBytes() != data.GetSize()) {
144     MS_LOG(ERROR) << "Unexpected length. Read " << data.GetSize() << ". Expected " << ts->SizeInBytes() << ".\n"
145                   << "Dumping tensor\n"
146                   << *ts << "\n";
147     RETURN_STATUS_UNEXPECTED("Length mismatch. See log file for details.");
148   }
149   *out = std::move(ts);
150   return Status::OK();
151 }
152 }  // namespace dataset
153 }  // namespace mindspore
154