• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #include "minddata/mindrecord/include/shard_schema.h"
18 #include "utils/ms_utils.h"
19 
20 namespace mindspore {
21 namespace mindrecord {
Build(std::string desc,const json & schema)22 std::shared_ptr<Schema> Schema::Build(std::string desc, const json &schema) {
23   // validate check
24   if (!Validate(schema)) {
25     return nullptr;
26   }
27 
28   std::vector<std::string> blob_fields = PopulateBlobFields(schema);
29   Schema object_schema;
30   object_schema.desc_ = std::move(desc);
31   object_schema.blob_fields_ = std::move(blob_fields);
32   object_schema.schema_ = schema;
33   object_schema.schema_id_ = -1;
34   return std::make_shared<Schema>(object_schema);
35 }
36 
GetDesc() const37 std::string Schema::GetDesc() const { return desc_; }
38 
GetSchema() const39 json Schema::GetSchema() const {
40   json str_schema;
41   str_schema["desc"] = desc_;
42   str_schema["schema"] = schema_;
43   str_schema["blob_fields"] = blob_fields_;
44   return str_schema;
45 }
46 
SetSchemaID(int64_t id)47 void Schema::SetSchemaID(int64_t id) { schema_id_ = id; }
48 
GetSchemaID() const49 int64_t Schema::GetSchemaID() const { return schema_id_; }
50 
GetBlobFields() const51 std::vector<std::string> Schema::GetBlobFields() const { return blob_fields_; }
52 
PopulateBlobFields(json schema)53 std::vector<std::string> Schema::PopulateBlobFields(json schema) {
54   std::vector<std::string> blob_fields;
55   for (json::iterator it = schema.begin(); it != schema.end(); ++it) {
56     json it_value = it.value();
57     if ((it_value.size() == kInt2 && it_value.find("shape") != it_value.end()) || it_value["type"] == "bytes") {
58       blob_fields.emplace_back(it.key());
59     }
60   }
61   return blob_fields;
62 }
63 
ValidateNumberShape(const json & it_value)64 bool Schema::ValidateNumberShape(const json &it_value) {
65   if (it_value.find("shape") == it_value.end()) {
66     MS_LOG(ERROR) << "Invalid schema, 'shape' object can not found in " << it_value.dump()
67                   << ". Please check the input schema.";
68     return false;
69   }
70 
71   auto shape = it_value["shape"];
72   if (!shape.is_array()) {
73     MS_LOG(ERROR) << "Invalid schema, the value of 'shape' should be list format but got: " << it_value["shape"]
74                   << ". Please check the input schema.";
75     return false;
76   }
77 
78   int num_negtive_one = 0;
79   for (const auto &i : shape) {
80     if (i == 0 || i < -1) {
81       MS_LOG(ERROR) << "Invalid schema, the element of 'shape' value should be -1 or greater than 0 but got: " << i
82                     << ". Please check the input schema.";
83       return false;
84     }
85     if (i == -1) {
86       num_negtive_one++;
87     }
88   }
89 
90   if (num_negtive_one > 1) {
91     MS_LOG(ERROR) << "Invalid schema, only 1 variable dimension(-1) allowed in 'shape' value but got: "
92                   << it_value["shape"] << ". Please check the input schema.";
93     return false;
94   }
95 
96   return true;
97 }
98 
Validate(json schema)99 bool Schema::Validate(json schema) {
100   if (schema.empty()) {
101     MS_LOG(ERROR) << "Invalid schema, schema is empty. Please check the input schema.";
102     return false;
103   }
104 
105   for (json::iterator it = schema.begin(); it != schema.end(); ++it) {
106     // make sure schema key name must be composed of '0-9' or 'a-z' or 'A-Z' or '_'
107     if (!ValidateFieldName(it.key())) {
108       MS_LOG(ERROR) << "Invalid schema, field name: " << it.key()
109                     << "is not composed of '0-9' or 'a-z' or 'A-Z' or '_'. Please rename the field name in schema.";
110       return false;
111     }
112 
113     json it_value = it.value();
114     if (it_value.find("type") == it_value.end()) {
115       MS_LOG(ERROR) << "Invalid schema, 'type' object can not found in field " << it_value.dump()
116                     << ". Please add the 'type' object for field in schema.";
117       return false;
118     }
119 
120     if (kFieldTypeSet.find(it_value["type"]) == kFieldTypeSet.end()) {
121       MS_LOG(ERROR) << "Invalid schema, the value of 'type': " << it_value["type"]
122                     << " is not supported.\nPlease modify the value of 'type' to 'int32', 'int64', 'float32', "
123                        "'float64', 'string', 'bytes' in schema.";
124       return false;
125     }
126 
127     if (it_value.size() == kInt1) {
128       continue;
129     }
130 
131     if (it_value["type"] == "bytes" || it_value["type"] == "string") {
132       MS_LOG(ERROR)
133         << "Invalid schema, no other field can be added when the value of 'type' is 'string' or 'types' but got: "
134         << it_value.dump() << ". Please remove other fields in schema.";
135       return false;
136     }
137 
138     if (it_value.size() != kInt2) {
139       MS_LOG(ERROR) << "Invalid schema, the fields should be 'type' or 'type' and 'shape' but got: " << it_value.dump()
140                     << ". Please check the schema.";
141       return false;
142     }
143 
144     if (!ValidateNumberShape(it_value)) {
145       return false;
146     }
147   }
148 
149   return true;
150 }
151 
operator ==(const mindrecord::Schema & b) const152 bool Schema::operator==(const mindrecord::Schema &b) const {
153   if (this->GetDesc() != b.GetDesc() || this->GetSchema() != b.GetSchema()) {
154     return false;
155   }
156   return true;
157 }
158 }  // namespace mindrecord
159 }  // namespace mindspore
160