/**
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 #ifndef MINDSPORE_TENSOR_SUMMARY_H
17 #define MINDSPORE_TENSOR_SUMMARY_H
18 
19 #include <vector>
20 #include <tuple>
21 #include <memory>
22 #include <string>
23 
24 #include "utils/hash_map.h"
25 #include "debug/debug_services.h"
26 
27 namespace mindspore {
// Per-element calculator configured with a start and an end bound.
// The constructor, ProcessElement() and GetPercentInRange() are defined outside this header, so the
// initial bounds and the exact percentage formula are not visible here.
class RangeCountCalculator {
 public:
  RangeCountCalculator();
  ~RangeCountCalculator() = default;

  // Feeds one element into the calculator.
  void ProcessElement(double element);
  // Returns the in-range percentage; presumably derived from `count` and `total` — confirm in the
  // implementation file.
  double GetPercentInRange() const;

  // Stores `value` as the start bound of the range.
  void set_range_start_inclusive(double value) { range_start_inclusive = value; }
  // Stores `value` as the end bound of the range.
  void set_range_end_inclusive(double value) { range_end_inclusive = value; }

 private:
  double range_start_inclusive;
  double range_end_inclusive;
  int count;
  int total;
};
43 
// Pairwise calculator configured with an absolute (`atol`) and a relative (`rtol`) tolerance.
// The constructor, ProcessElement() and IsAllClose() are defined outside this header, so the default
// tolerances and the exact closeness test are not visible here.
class AllCloseCalculator {
 public:
  AllCloseCalculator();
  ~AllCloseCalculator() = default;

  // Feeds one (current, previous) element pair into the calculator.
  void ProcessElement(double current, double previous);
  // Reports the verdict; presumably the value held in `result` — confirm in the implementation file.
  bool IsAllClose() const;

  // Stores `value` as the absolute tolerance.
  void set_atol(double value) { atol = value; }
  // Stores `value` as the relative tolerance.
  void set_rtol(double value) { rtol = value; }

 private:
  double atol;
  double rtol;
  bool result;
};
58 
// Calculator that consumes values one at a time and exposes a mean.
// All member functions are defined outside this header; the update rule is not visible here.
class MeanCalculator {
 public:
  MeanCalculator();
  ~MeanCalculator() = default;

  // Feeds one value into the calculator.
  void ProcessElement(double value);
  // Returns the mean; presumably the value held in `mean` — confirm in the implementation file.
  double GetMean() const;

 protected:
  // State is protected rather than private, so derived classes can read and write it directly.
  double mean;
  int count;
};
70 
// Calculator that consumes values one at a time and exposes mean, variance and standard deviation.
// All member functions are defined outside this header.
// NOTE(review): the `mean` / `count` / `m2` state suggests a single-pass running update, but the
// algorithm itself is not visible here — confirm in the implementation file.
class VarianceAndMeanCalculator {
 public:
  VarianceAndMeanCalculator();
  ~VarianceAndMeanCalculator() = default;

  // Feeds one value into the calculator.
  void ProcessElement(double value);
  double GetStandardDeviation() const;
  double GetVariance() const;
  double GetMean() const;

 private:
  double mean;
  int count;
  double m2;
};
85 
// Calculator that consumes values and exposes an L2 value.
// ProcessElement() and GetL2Value() are defined outside this header.
class L2Calculator {
 public:
  // Starts with an accumulator of 0.0.
  L2Calculator() : squre_sum(0.0) {}
  ~L2Calculator() = default;

  // Feeds one value into the calculator.
  void ProcessElement(double value);
  // Feeds another calculator into this one; presumably merges its accumulated state — confirm in the
  // implementation file.
  void ProcessElement(const L2Calculator &other);
  double GetL2Value() const;

 private:
  // Running accumulator for the L2 computation.
  // NOTE(review): the previous comment described a scaled accumulator ("(x^2 + y^2)/y^2 when y > x, to
  // avoid intermediate value overflow", with the true L2 value being
  // sqrt(squre_sum_div_max_ * max_value_^2)), but neither squre_sum_div_max_ nor max_value_ is declared
  // in this class. Confirm against the implementation which scheme is actually in use.
  double squre_sum;
};
99 
100 class ITensorSummary {
101  public:
102   enum WatchpointPos { eHitPos = 0, eErrorCodePos = 1, eParamListPos = 2 };
103   enum ErrorCode {
104     NAN_TENSOR = 0,
105     INF_TENSOR = 2,
106     NULL_PREV_TENSOR = 4,
107     OUT_OF_MEMORY = 8,
108     HISTORY_NOT_FOUND = 16,
109     NO_VALUE = 32
110   };
111   virtual ~ITensorSummary() = default;
112   virtual void SummarizeTensor(const std::vector<DebugServices::watchpoint_t> &wps) = 0;
113   virtual std::tuple<bool, int32_t, std::vector<DebugServices::parameter_t>> IsWatchpointHit(
114     DebugServices::watchpoint_t) = 0;
115   virtual void TensorStatistics(DbgDataType dtype_value) = 0;
116   virtual const bool is_bool() const = 0;
117   virtual const double max_value() const = 0;
118   virtual const double min_value() const = 0;
119   virtual const double avg_value() const = 0;
120   virtual const double l2_value() const = 0;
121 
122   virtual const uint64_t count() const = 0;
123   virtual const uint64_t neg_zero_count() const = 0;
124   virtual const uint64_t pos_zero_count() const = 0;
125   virtual const uint64_t nan_count() const = 0;
126   virtual const uint64_t neg_inf_count() const = 0;
127   virtual const uint64_t pos_inf_count() const = 0;
128   virtual const uint64_t zero_count() const = 0;
129 };
130 
131 template <typename T>
132 class TensorSummary : public ITensorSummary {
133  public:
134   TensorSummary() = default;
135   ~TensorSummary() override = default;
136   TensorSummary(const void *current_tensor_ptr, const void *const previous_tensor_ptr, uint64_t num_elements,
137                 uint64_t prev_num_elements);
138   void SummarizeTensor(const std::vector<DebugServices::watchpoint_t> &wps) override;
139   // returns hit, error_code, parameter_list
140   std::tuple<bool, int, std::vector<DebugServices::parameter_t>> IsWatchpointHit(
141     DebugServices::watchpoint_t wp) override;
142   void TensorStatistics(DbgDataType dtype_value) override;
is_bool()143   const bool is_bool() const override { return is_bool_; }
max_value()144   const double max_value() const override { return max_; }
min_value()145   const double min_value() const override { return min_; }
avg_value()146   const double avg_value() const override { return avg_; }
count()147   const uint64_t count() const override { return num_elements_; }
neg_zero_count()148   const uint64_t neg_zero_count() const override { return neg_zero_count_; }
pos_zero_count()149   const uint64_t pos_zero_count() const override { return pos_zero_count_; }
nan_count()150   const uint64_t nan_count() const override { return nan_count_; }
neg_inf_count()151   const uint64_t neg_inf_count() const override { return neg_inf_count_; }
pos_inf_count()152   const uint64_t pos_inf_count() const override { return pos_inf_count_; }
zero_count()153   const uint64_t zero_count() const override { return zero_count_; }
l2_value()154   const double l2_value() const override { return l2_calc_.GetL2Value(); }
155 
156  private:
157   const T *current_tensor_ptr_;
158   const T *prev_tensor_ptr_;
159   uint64_t num_elements_;
160   uint64_t prev_num_elements_;
161   double min_;
162   double max_;
163   double avg_;
164   bool is_bool_;
165   uint64_t neg_zero_count_;
166   uint64_t pos_zero_count_;
167   uint64_t pos_inf_count_;
168   uint64_t neg_inf_count_;
169   uint64_t inf_count_;
170   uint64_t nan_count_;
171   uint64_t zero_count_;
172   double epsilon_;
173   bool mean_sd_cal_enabled_;
174   VarianceAndMeanCalculator current_mean_variance_;
175   L2Calculator l2_calc_;
176   mindspore::HashMap<std::string, std::unique_ptr<MeanCalculator>> means_;
177   mindspore::HashMap<uint32_t, std::unique_ptr<AllCloseCalculator>> all_close_;
178   mindspore::HashMap<uint32_t, std::unique_ptr<RangeCountCalculator>> range_counts_;
179   double_t StatLookup(const DebugServices::watchpoint_t &wp) const;
180   double_t StatLookup(const std::string &parameter_name, const DebugServices::watchpoint_t &wp);
181   double_t GetZeroValPercent() const;
182   void TensorStatisticsSingleThread();
183   void InitCalculators(const std::vector<DebugServices::watchpoint_t> &);
184 };
185 }  // namespace mindspore
186 #endif  // MINDSPORE_TENSOR_SUMMARY_H
187