1 /**
2 * Copyright 2019 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "minddata/mindrecord/include/shard_schema.h"
18 #include "utils/ms_utils.h"
19
20 namespace mindspore {
21 namespace mindrecord {
Build(std::string desc,const json & schema)22 std::shared_ptr<Schema> Schema::Build(std::string desc, const json &schema) {
23 // validate check
24 if (!Validate(schema)) {
25 return nullptr;
26 }
27
28 std::vector<std::string> blob_fields = PopulateBlobFields(schema);
29 Schema object_schema;
30 object_schema.desc_ = std::move(desc);
31 object_schema.blob_fields_ = std::move(blob_fields);
32 object_schema.schema_ = schema;
33 object_schema.schema_id_ = -1;
34 return std::make_shared<Schema>(object_schema);
35 }
36
GetDesc() const37 std::string Schema::GetDesc() const { return desc_; }
38
GetSchema() const39 json Schema::GetSchema() const {
40 json str_schema;
41 str_schema["desc"] = desc_;
42 str_schema["schema"] = schema_;
43 str_schema["blob_fields"] = blob_fields_;
44 return str_schema;
45 }
46
SetSchemaID(int64_t id)47 void Schema::SetSchemaID(int64_t id) { schema_id_ = id; }
48
GetSchemaID() const49 int64_t Schema::GetSchemaID() const { return schema_id_; }
50
GetBlobFields() const51 std::vector<std::string> Schema::GetBlobFields() const { return blob_fields_; }
52
PopulateBlobFields(json schema)53 std::vector<std::string> Schema::PopulateBlobFields(json schema) {
54 std::vector<std::string> blob_fields;
55 for (json::iterator it = schema.begin(); it != schema.end(); ++it) {
56 json it_value = it.value();
57 if ((it_value.size() == kInt2 && it_value.find("shape") != it_value.end()) || it_value["type"] == "bytes") {
58 blob_fields.emplace_back(it.key());
59 }
60 }
61 return blob_fields;
62 }
63
ValidateNumberShape(const json & it_value)64 bool Schema::ValidateNumberShape(const json &it_value) {
65 if (it_value.find("shape") == it_value.end()) {
66 MS_LOG(ERROR) << "Invalid schema, 'shape' object can not found in " << it_value.dump()
67 << ". Please check the input schema.";
68 return false;
69 }
70
71 auto shape = it_value["shape"];
72 if (!shape.is_array()) {
73 MS_LOG(ERROR) << "Invalid schema, the value of 'shape' should be list format but got: " << it_value["shape"]
74 << ". Please check the input schema.";
75 return false;
76 }
77
78 int num_negtive_one = 0;
79 for (const auto &i : shape) {
80 if (i == 0 || i < -1) {
81 MS_LOG(ERROR) << "Invalid schema, the element of 'shape' value should be -1 or greater than 0 but got: " << i
82 << ". Please check the input schema.";
83 return false;
84 }
85 if (i == -1) {
86 num_negtive_one++;
87 }
88 }
89
90 if (num_negtive_one > 1) {
91 MS_LOG(ERROR) << "Invalid schema, only 1 variable dimension(-1) allowed in 'shape' value but got: "
92 << it_value["shape"] << ". Please check the input schema.";
93 return false;
94 }
95
96 return true;
97 }
98
Validate(json schema)99 bool Schema::Validate(json schema) {
100 if (schema.empty()) {
101 MS_LOG(ERROR) << "Invalid schema, schema is empty. Please check the input schema.";
102 return false;
103 }
104
105 for (json::iterator it = schema.begin(); it != schema.end(); ++it) {
106 // make sure schema key name must be composed of '0-9' or 'a-z' or 'A-Z' or '_'
107 if (!ValidateFieldName(it.key())) {
108 MS_LOG(ERROR) << "Invalid schema, field name: " << it.key()
109 << "is not composed of '0-9' or 'a-z' or 'A-Z' or '_'. Please rename the field name in schema.";
110 return false;
111 }
112
113 json it_value = it.value();
114 if (it_value.find("type") == it_value.end()) {
115 MS_LOG(ERROR) << "Invalid schema, 'type' object can not found in field " << it_value.dump()
116 << ". Please add the 'type' object for field in schema.";
117 return false;
118 }
119
120 if (kFieldTypeSet.find(it_value["type"]) == kFieldTypeSet.end()) {
121 MS_LOG(ERROR) << "Invalid schema, the value of 'type': " << it_value["type"]
122 << " is not supported.\nPlease modify the value of 'type' to 'int32', 'int64', 'float32', "
123 "'float64', 'string', 'bytes' in schema.";
124 return false;
125 }
126
127 if (it_value.size() == kInt1) {
128 continue;
129 }
130
131 if (it_value["type"] == "bytes" || it_value["type"] == "string") {
132 MS_LOG(ERROR)
133 << "Invalid schema, no other field can be added when the value of 'type' is 'string' or 'types' but got: "
134 << it_value.dump() << ". Please remove other fields in schema.";
135 return false;
136 }
137
138 if (it_value.size() != kInt2) {
139 MS_LOG(ERROR) << "Invalid schema, the fields should be 'type' or 'type' and 'shape' but got: " << it_value.dump()
140 << ". Please check the schema.";
141 return false;
142 }
143
144 if (!ValidateNumberShape(it_value)) {
145 return false;
146 }
147 }
148
149 return true;
150 }
151
operator ==(const mindrecord::Schema & b) const152 bool Schema::operator==(const mindrecord::Schema &b) const {
153 if (this->GetDesc() != b.GetDesc() || this->GetSchema() != b.GetSchema()) {
154 return false;
155 }
156 return true;
157 }
158 } // namespace mindrecord
159 } // namespace mindspore
160