• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "fl/server/iteration_metrics.h"
18 #include <string>
19 #include <fstream>
20 #include "debug/common.h"
21 #include "ps/constants.h"
22 
23 namespace mindspore {
24 namespace fl {
25 namespace server {
Initialize()26 bool IterationMetrics::Initialize() {
27   config_ = std::make_unique<ps::core::FileConfiguration>(config_file_path_);
28   MS_EXCEPTION_IF_NULL(config_);
29   if (!config_->Initialize()) {
30     MS_LOG(WARNING) << "Initializing for metrics failed. Config file path " << config_file_path_
31                     << " may be invalid or not exist.";
32     return false;
33   }
34 
35   // Read the metrics file path. If file is not set or not exits, create one.
36   if (!config_->Exists(kMetrics)) {
37     MS_LOG(WARNING) << "Metrics config is not set. Don't write metrics.";
38     return false;
39   } else {
40     std::string value = config_->Get(kMetrics, "");
41     nlohmann::json value_json;
42     try {
43       value_json = nlohmann::json::parse(value);
44     } catch (const std::exception &e) {
45       MS_LOG(EXCEPTION) << "The hyper-parameter data is not in json format.";
46       return false;
47     }
48 
49     // Parse the storage type.
50     uint32_t storage_type = JsonGetKeyWithException<uint32_t>(value_json, ps::kStoreType);
51     if (std::to_string(storage_type) != ps::kFileStorage) {
52       MS_LOG(EXCEPTION) << "Storage type " << storage_type << " is not supported.";
53       return false;
54     }
55 
56     // Parse storage file path.
57     metrics_file_path_ = JsonGetKeyWithException<std::string>(value_json, ps::kStoreFilePath);
58     auto realpath = Common::CreatePrefixPath(metrics_file_path_.c_str());
59     if (!realpath.has_value()) {
60       MS_LOG(EXCEPTION) << "Creating path for " << metrics_file_path_ << " failed.";
61       return false;
62     }
63 
64     metrics_file_.open(realpath.value(), std::ios::ate | std::ios::out);
65     metrics_file_.close();
66   }
67   return true;
68 }
69 
Summarize()70 bool IterationMetrics::Summarize() {
71   metrics_file_.open(metrics_file_path_, std::ios::ate | std::ios::out);
72   if (!metrics_file_.is_open()) {
73     MS_LOG(ERROR) << "The metrics file is not opened.";
74     return false;
75   }
76 
77   js_[kFLName] = fl_name_;
78   js_[kInstanceStatus] = kInstanceStateName.at(instance_state_);
79   js_[kFLIterationNum] = fl_iteration_num_;
80   js_[kCurIteration] = cur_iteration_num_;
81   js_[kJoinedClientNum] = joined_client_num_;
82   js_[kRejectedClientNum] = rejected_client_num_;
83   js_[kMetricsAuc] = accuracy_;
84   js_[kMetricsLoss] = loss_;
85   js_[kIterExecutionTime] = iteration_time_cost_;
86   metrics_file_ << js_ << "\n";
87   (void)metrics_file_.flush();
88   metrics_file_.close();
89   return true;
90 }
91 
Clear()92 bool IterationMetrics::Clear() {
93   if (metrics_file_.is_open()) {
94     MS_LOG(INFO) << "Clear the old metrics file " << metrics_file_path_;
95     metrics_file_.close();
96     metrics_file_.open(metrics_file_path_, std::ios::ate | std::ios::out);
97   }
98   return true;
99 }
100 
set_fl_name(const std::string & fl_name)101 void IterationMetrics::set_fl_name(const std::string &fl_name) { fl_name_ = fl_name; }
102 
set_fl_iteration_num(size_t fl_iteration_num)103 void IterationMetrics::set_fl_iteration_num(size_t fl_iteration_num) { fl_iteration_num_ = fl_iteration_num; }
104 
set_cur_iteration_num(size_t cur_iteration_num)105 void IterationMetrics::set_cur_iteration_num(size_t cur_iteration_num) { cur_iteration_num_ = cur_iteration_num; }
106 
set_instance_state(InstanceState state)107 void IterationMetrics::set_instance_state(InstanceState state) { instance_state_ = state; }
108 
set_loss(float loss)109 void IterationMetrics::set_loss(float loss) { loss_ = loss; }
110 
set_accuracy(float acc)111 void IterationMetrics::set_accuracy(float acc) { accuracy_ = acc; }
112 
set_joined_client_num(size_t joined_client_num)113 void IterationMetrics::set_joined_client_num(size_t joined_client_num) { joined_client_num_ = joined_client_num; }
114 
set_rejected_client_num(size_t rejected_client_num)115 void IterationMetrics::set_rejected_client_num(size_t rejected_client_num) {
116   rejected_client_num_ = rejected_client_num;
117 }
118 
set_iteration_time_cost(uint64_t iteration_time_cost)119 void IterationMetrics::set_iteration_time_cost(uint64_t iteration_time_cost) {
120   iteration_time_cost_ = iteration_time_cost;
121 }
122 }  // namespace server
123 }  // namespace fl
124 }  // namespace mindspore
125