1 /**
2 * Copyright 2019 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "minddata/mindrecord/include/shard_statistics.h"
18 #include "pybind11/pybind11.h"
19
20 namespace mindspore {
21 namespace mindrecord {
Build(std::string desc,const json & statistics)22 std::shared_ptr<Statistics> Statistics::Build(std::string desc, const json &statistics) {
23 // validate check
24 if (!Validate(statistics)) {
25 return nullptr;
26 }
27 Statistics object_statistics;
28 object_statistics.desc_ = std::move(desc);
29 object_statistics.statistics_ = statistics;
30 object_statistics.statistics_id_ = -1;
31 return std::make_shared<Statistics>(object_statistics);
32 }
33
GetDesc() const34 std::string Statistics::GetDesc() const { return desc_; }
35
GetStatistics() const36 json Statistics::GetStatistics() const {
37 json str_statistics;
38 str_statistics["desc"] = desc_;
39 str_statistics["statistics"] = statistics_;
40 return str_statistics;
41 }
42
SetStatisticsID(int64_t id)43 void Statistics::SetStatisticsID(int64_t id) { statistics_id_ = id; }
44
GetStatisticsID() const45 int64_t Statistics::GetStatisticsID() const { return statistics_id_; }
46
Validate(const json & statistics)47 bool Statistics::Validate(const json &statistics) {
48 if (statistics.size() != kInt1) {
49 MS_LOG(ERROR) << "Invalid data, 'statistics' is empty.";
50 return false;
51 }
52 if (statistics.find("level") == statistics.end()) {
53 MS_LOG(ERROR) << "Invalid data, 'level' object can not found in statistic";
54 return false;
55 }
56 return LevelRecursive(statistics["level"]);
57 }
58
LevelRecursive(json level)59 bool Statistics::LevelRecursive(json level) {
60 bool ini = true;
61 for (json::iterator it = level.begin(); it != level.end(); ++it) {
62 json a = it.value();
63 if (a.size() == kInt2) {
64 if ((a.find("key") == a.end()) || (a.find("count") == a.end())) {
65 MS_LOG(ERROR) << "Invalid data, the node field is 2, but 'key'/'count' object does not existed";
66 return false;
67 }
68 } else if (a.size() == kInt3) {
69 if ((a.find("key") == a.end()) || (a.find("count") == a.end()) || a.find("level") == a.end()) {
70 MS_LOG(ERROR) << "Invalid data, the node field is 3, but 'key'/'count'/'level' object does not existed";
71 return false;
72 } else {
73 ini = LevelRecursive(a.at("level"));
74 }
75 } else {
76 MS_LOG(ERROR) << "Invalid data, the node field is not equal to 2 or 3";
77 return false;
78 }
79 }
80 return ini;
81 }
82
operator ==(const Statistics & b) const83 bool Statistics::operator==(const Statistics &b) const {
84 if (this->GetStatistics() != b.GetStatistics()) {
85 return false;
86 }
87 return true;
88 }
89 } // namespace mindrecord
90 } // namespace mindspore
91