• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "minddata/mindrecord/include/shard_schema.h"
18 #include "utils/ms_utils.h"
19 
20 using mindspore::LogStream;
21 using mindspore::ExceptionType::NoExceptionType;
22 using mindspore::MsLogLevel::ERROR;
23 
24 namespace mindspore {
25 namespace mindrecord {
Build(std::string desc,const json & schema)26 std::shared_ptr<Schema> Schema::Build(std::string desc, const json &schema) {
27   // validate check
28   if (!Validate(schema)) {
29     return nullptr;
30   }
31 
32   std::vector<std::string> blob_fields = PopulateBlobFields(schema);
33   Schema object_schema;
34   object_schema.desc_ = std::move(desc);
35   object_schema.blob_fields_ = std::move(blob_fields);
36   object_schema.schema_ = schema;
37   object_schema.schema_id_ = -1;
38   return std::make_shared<Schema>(object_schema);
39 }
40 
GetDesc() const41 std::string Schema::GetDesc() const { return desc_; }
42 
GetSchema() const43 json Schema::GetSchema() const {
44   json str_schema;
45   str_schema["desc"] = desc_;
46   str_schema["schema"] = schema_;
47   str_schema["blob_fields"] = blob_fields_;
48   return str_schema;
49 }
50 
SetSchemaID(int64_t id)51 void Schema::SetSchemaID(int64_t id) { schema_id_ = id; }
52 
GetSchemaID() const53 int64_t Schema::GetSchemaID() const { return schema_id_; }
54 
GetBlobFields() const55 std::vector<std::string> Schema::GetBlobFields() const { return blob_fields_; }
56 
PopulateBlobFields(json schema)57 std::vector<std::string> Schema::PopulateBlobFields(json schema) {
58   std::vector<std::string> blob_fields;
59   for (json::iterator it = schema.begin(); it != schema.end(); ++it) {
60     json it_value = it.value();
61     if ((it_value.size() == kInt2 && it_value.find("shape") != it_value.end()) || it_value["type"] == "bytes") {
62       blob_fields.emplace_back(it.key());
63     }
64   }
65   return blob_fields;
66 }
67 
ValidateNumberShape(const json & it_value)68 bool Schema::ValidateNumberShape(const json &it_value) {
69   if (it_value.find("shape") == it_value.end()) {
70     MS_LOG(ERROR) << "Invalid data, 'shape' object can not found in " << it_value.dump();
71     return false;
72   }
73 
74   auto shape = it_value["shape"];
75   if (!shape.is_array()) {
76     MS_LOG(ERROR) << "Invalid data, shape [" << it_value["shape"].dump() << "] is invalid.";
77     return false;
78   }
79 
80   int num_negtive_one = 0;
81   for (const auto &i : shape) {
82     if (i == 0 || i < -1) {
83       MS_LOG(ERROR) << "Invalid data, shape [" << it_value["shape"].dump() << "]dimension is invalid.";
84       return false;
85     }
86     if (i == -1) {
87       num_negtive_one++;
88     }
89   }
90 
91   if (num_negtive_one > 1) {
92     MS_LOG(ERROR) << "Invalid data, shape [" << it_value["shape"].dump()
93                   << "] have more than 1 variable dimension(-1).";
94     return false;
95   }
96 
97   return true;
98 }
99 
Validate(json schema)100 bool Schema::Validate(json schema) {
101   if (schema.size() == kInt0) {
102     MS_LOG(ERROR) << "Invalid data, schema is empty.";
103     return false;
104   }
105 
106   for (json::iterator it = schema.begin(); it != schema.end(); ++it) {
107     // make sure schema key name must be composed of '0-9' or 'a-z' or 'A-Z' or '_'
108     if (!ValidateFieldName(it.key())) {
109       MS_LOG(ERROR) << "Invalid data, field [" << it.key()
110                     << "] in schema is not composed of '0-9' or 'a-z' or 'A-Z' or '_'.";
111       return false;
112     }
113 
114     json it_value = it.value();
115     if (it_value.find("type") == it_value.end()) {
116       MS_LOG(ERROR) << "Invalid data, 'type' object can not found in field [" << it_value.dump() << "].";
117       return false;
118     }
119 
120     if (kFieldTypeSet.find(it_value["type"]) == kFieldTypeSet.end()) {
121       MS_LOG(ERROR) << "Invalid data, type [" << it_value["type"].dump() << "] is not supported.";
122       return false;
123     }
124 
125     if (it_value.size() == kInt1) {
126       continue;
127     }
128 
129     if (it_value["type"] == "bytes" || it_value["type"] == "string") {
130       MS_LOG(ERROR) << "Invalid data, field [" << it_value.dump() << "] is invalid.";
131       return false;
132     }
133 
134     if (it_value.size() != kInt2) {
135       MS_LOG(ERROR) << "Invalid data, field [" << it_value.dump() << "] is invalid.";
136       return false;
137     }
138 
139     if (!ValidateNumberShape(it_value)) {
140       return false;
141     }
142   }
143 
144   return true;
145 }
146 
operator ==(const mindrecord::Schema & b) const147 bool Schema::operator==(const mindrecord::Schema &b) const {
148   if (this->GetDesc() != b.GetDesc() || this->GetSchema() != b.GetSchema()) {
149     return false;
150   }
151   return true;
152 }
153 }  // namespace mindrecord
154 }  // namespace mindspore
155