• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
#include "include/common/utils/dynamic_obfuscation/dynamic_obfuscation.h"
#include <algorithm>
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <iterator>
#include <map>
#include <memory>
#include <numeric>
#include <random>
#include <stdexcept>
#include <string>
#include <vector>
#include "ops/conv_pool_op_name.h"
#include "ops/math_op_name.h"
#include "ops/other_ops.h"
#include "ops/comparison_ops.h"
#include "ops/array_ops.h"
#include "ops/auto_generate/gen_ops_primitive.h"
#include "ops/framework_ops.h"
#include "include/common/debug/anf_ir_dump.h"
#include "include/common/utils/dynamic_obfuscation/registry_opaque_predicate.h"
#include "include/common/utils/utils.h"
#include "ir/anf.h"
#include "ir/tensor.h"
#include "utils/info.h"
35 
36 namespace mindspore {
37 namespace {
AddObfuscatedParam(FuncGraphPtr func_graph)38 ParameterPtr AddObfuscatedParam(FuncGraphPtr func_graph) {
39   auto params = func_graph->parameters();
40   auto add_param = std::make_shared<Parameter>(func_graph);
41   std::vector<AnfNodePtr> new_para_list(params.begin(), params.begin() + params.size() - func_graph->fv_param_count());
42   (void)new_para_list.emplace_back(add_param);
43   (void)new_para_list.insert(new_para_list.cend(), params.begin() + params.size() - func_graph->fv_param_count(),
44                              params.end());
45   func_graph->set_parameters(new_para_list);
46   return add_param;
47 }
48 }  // namespace
49 using Tensor = mindspore::tensor::Tensor;
50 using mindspore::abstract::AbstractTensor;
51 using mindspore::abstract::AbstractTensorPtr;
52 using mindspore::abstract::AbstractTuple;
53 using mindspore::abstract::AbstractTuplePtr;
54 
// Graph-construction constants for the obfuscation routines in this file (not all use sites are visible here).
constexpr int keyExpandRate = 10;  // total number of nodes needed to build one switch graph
// NOTE(review): presumably the CNode input index of the weight, i.e. the third entry of the
// {primitive, input, weight} input list built by BuildOneInputWithWeightNode -- TODO confirm at its use sites.
constexpr int kWeightIndex = 2;
// NOTE(review): presumably the expected number of inputs of a Switch-related node -- TODO confirm.
constexpr int kSwitchInputsNum = 2;
// Presumably the input count of a {primitive, input, weight} CNode as built by BuildOneInputWithWeightNode.
constexpr int kNodeWithWeightInputsNum = 3;
59 
get_node_shape(const AnfNodePtr & input_node)60 ShapeVector get_node_shape(const AnfNodePtr &input_node) {
61   if (input_node == nullptr) {
62     MS_LOG(ERROR) << "Input node is nullptr, get shape failed!";
63     return {};
64   }
65   AbstractBasePtr input_abstract = input_node->abstract();
66   if (input_abstract == nullptr) {
67     MS_LOG(ERROR) << "The abstract of input_node is nullptr, get shape failed!";
68     return {};
69   }
70   AbstractTensorPtr input_abstract_tensor = input_abstract->cast<mindspore::abstract::AbstractTensorPtr>();
71   MS_EXCEPTION_IF_NULL(input_abstract_tensor);
72   mindspore::abstract::ShapePtr shape_ptr = input_abstract_tensor->shape();
73   if (shape_ptr == nullptr) {
74     return {};
75   }
76   return shape_ptr->shape();
77 }
78 
get_node_dtype(const AnfNodePtr & input_node)79 TypeId get_node_dtype(const AnfNodePtr &input_node) {
80   if (input_node == nullptr) {
81     MS_LOG(ERROR) << "Input node is nullptr, get dtype failed!";
82     return {};
83   }
84   AbstractBasePtr input_abstract = input_node->abstract();
85   if (input_abstract == nullptr) {
86     MS_LOG(ERROR) << "The abstract of input_node is nullptr, get dtype failed!";
87     return {};
88   }
89   AbstractTensorPtr input_abstract_tensor = input_abstract->cast<mindspore::abstract::AbstractTensorPtr>();
90   MS_EXCEPTION_IF_NULL(input_abstract_tensor);
91   AbstractBasePtr node_element = input_abstract_tensor->element();
92   mindspore::abstract::AbstractScalarPtr node_element_abs =
93     node_element->cast<mindspore::abstract::AbstractScalarPtr>();
94   MS_EXCEPTION_IF_NULL(node_element_abs);
95   TypeId data_type = node_element_abs->BuildType()->type_id();
96   return data_type;
97 }
98 
/**
 * Split a string on every (non-overlapping, left-to-right) occurrence of a separator.
 *
 * Empty fields are kept, so the result always has one more element than the number of separators found.
 * Examples: ("a-b", "-") -> {"a", "b"}; ("a-", "-") -> {"a", ""}; ("", "-") -> {""}.
 *
 * The previous implementation appended the separator and scanned with an unsigned int index. That looped forever
 * on an empty separator, truncated sizes to 32 bits, and could mis-split when the separator overlaps itself.
 *
 * @param node_name_ String to split.
 * @param split_sign Separator; if empty, the whole input is returned as a single field.
 * @return The fields between separators.
 */
std::vector<std::string> name_split(const std::string &node_name_, const std::string &split_sign) {
  std::vector<std::string> res;
  if (split_sign.empty()) {
    // An empty separator matches everywhere; treat the input as one field instead of looping forever.
    res.push_back(node_name_);
    return res;
  }
  std::string::size_type start = 0;
  std::string::size_type split_pos = node_name_.find(split_sign, start);
  while (split_pos != std::string::npos) {
    res.push_back(node_name_.substr(start, split_pos - start));
    start = split_pos + split_sign.size();
    split_pos = node_name_.find(split_sign, start);
  }
  // The trailing field after the last separator (possibly empty).
  res.push_back(node_name_.substr(start));
  return res;
}
115 
get_node_prim_name(const AnfNodePtr & node)116 std::string get_node_prim_name(const AnfNodePtr &node) {
117   if (node == nullptr) {
118     MS_LOG(ERROR) << "Input node is nullptr, get name failed!";
119     return "";
120   }
121   PrimitivePtr node_prim = GetCNodePrimitive(node);
122   if (node_prim == nullptr) {
123     MS_LOG(DEBUG) << "The primitive of node " << node->fullname_with_scope() << " is nullptr!";
124     return "";
125   }
126   return node_prim->ToString();
127 }
128 
get_op_num(const AnfNodePtr & node)129 int get_op_num(const AnfNodePtr &node) {
130   if (node == nullptr) {
131     MS_LOG(ERROR) << "Input node is nullptr, get name failed!";
132     return 0;
133   }
134   std::string node_name = node->fullname_with_scope();
135   std::vector<string> split_words = name_split(node_name, "op");
136   if (split_words.empty()) {
137     MS_LOG(WARNING) << "Input node name is empty.";
138     return 0;
139   }
140   std::string num = split_words[split_words.size() - 1];
141   return std::stoi(num);
142 }
143 
get_node_param(const FuncGraphPtr func_graph,const CNodePtr & node)144 ParameterPtr get_node_param(const FuncGraphPtr func_graph, const CNodePtr &node) {
145   if (node == nullptr) {
146     MS_LOG(ERROR) << "Node is nullptr, get param failed!";
147     return nullptr;
148   }
149   if (func_graph == nullptr) {
150     MS_LOG(ERROR) << "FuncGraph is nullptr, get param failed!";
151     return nullptr;
152   }
153   std::string parameter_name = "";
154   for (auto &weak_input : node->weak_inputs()) {
155     auto input = weak_input.lock();
156     MS_EXCEPTION_IF_NULL(input);
157     std::string op_name = get_node_prim_name(input);
158     MS_LOG(INFO) << "op_name is: " << op_name;
159     if (op_name == "Load") {
160       for (auto weak_param : input->cast<mindspore::CNodePtr>()->weak_inputs()) {
161         auto param = weak_param.lock();
162         MS_EXCEPTION_IF_NULL(param);
163         if (param->fullname_with_scope().find("weight") != std::string::npos) {
164           parameter_name = param->fullname_with_scope();
165           break;
166         }
167       }
168     }
169   }
170   for (auto param : func_graph->parameters()) {
171     auto param_node = param->cast<mindspore::ParameterPtr>();
172     if (param_node == nullptr) {
173       MS_LOG(ERROR) << "Param node is nullptr.";
174       return nullptr;
175     }
176     if (param->fullname_with_scope() == parameter_name) {
177       return param_node;
178     }
179   }
180   return nullptr;
181 }
182 
build_tuple_value_node(const std::vector<int64_t> & values)183 ValueNodePtr build_tuple_value_node(const std::vector<int64_t> &values) {
184   mindspore::ValueNodePtr v_node = std::make_shared<mindspore::ValueNode>(MakeValue(values));
185   AbstractBasePtrList abs_list;
186   (void)std::transform(values.cbegin(), values.cend(), std::back_inserter(abs_list), [](const int64_t &item) {
187     return std::make_shared<mindspore::abstract::AbstractScalar>(int64_t(item));
188   });
189   auto abs_tuple = std::make_shared<mindspore::abstract::AbstractTuple>(abs_list);
190   v_node->set_abstract(abs_tuple);
191   return v_node;
192 }
193 
make_int_node(const FuncGraphPtr func_graph,int int_value)194 ValueNodePtr make_int_node(const FuncGraphPtr func_graph, int int_value) {
195   ShapeVector int_shape{1};
196   tensor::TensorPtr int_tensor = std::make_shared<Tensor>(mindspore::kNumberTypeInt32, int_shape);
197   int *tensor_data = reinterpret_cast<int *>(int_tensor->data_c());
198   for (int i = 0; i < int_tensor->data().size(); i++) {
199     tensor_data[i] = int_value;
200   }
201   mindspore::ValueNodePtr int_tensor_node = std::make_shared<mindspore::ValueNode>(int_tensor);
202   int_tensor_node->set_abstract(int_tensor->ToAbstract());
203   func_graph->AddValueNode(int_tensor_node);
204   return int_tensor_node;
205 }
206 
make_weight_tensor(TypeId type_id,ShapeVector shape)207 tensor::TensorPtr make_weight_tensor(TypeId type_id, ShapeVector shape) {
208   tensor::TensorPtr weight_tensor = std::make_shared<Tensor>(type_id, shape);
209   std::default_random_engine generator;
210   int max_count = 10000;
211   int tensor_size = SizeToInt(weight_tensor->data().size());
212   if (type_id == kNumberTypeFloat64) {
213     const double mean_64 = 0;
214     const double stddev_64 = 1;
215     std::normal_distribution<double> dist_64(mean_64, stddev_64);
216     double *float_64_data = reinterpret_cast<double *>(weight_tensor->data_c());
217     for (int i = 0; i < std::min(tensor_size, max_count); i++) {
218       double random_float_64 = dist_64(generator);
219       if (random_float_64 > 0) {
220         float_64_data[i] = random_float_64;
221       }
222     }
223   } else {
224     MS_LOG(DEBUG) << "Type id is: " << type_id << ", weights will be float_32 format.";
225     const float mean = 0;
226     const float stddev = 1;
227     std::normal_distribution<float> dist_32(mean, stddev);
228     float *float_32_data = reinterpret_cast<float *>(weight_tensor->data_c());
229     for (int i = 0; i < std::min(tensor_size, max_count); i++) {
230       float random_float_32 = dist_32(generator);
231       if (random_float_32 > 0) {
232         float_32_data[i] = random_float_32;
233       }
234     }
235   }
236   return weight_tensor;
237 }
238 
CheckIfObfuscated(const FuncGraphPtr & func_graph)239 bool CheckIfObfuscated(const FuncGraphPtr &func_graph) {
240   MS_EXCEPTION_IF_NULL(func_graph);
241   auto mgr = Manage(func_graph);
242   MS_EXCEPTION_IF_NULL(mgr);
243   auto all_nodes = mgr->all_nodes();
244   for (AnfNodePtr node : all_nodes) {
245     MS_EXCEPTION_IF_NULL(node);
246     std::string node_name = node->fullname_with_scope();
247     if (node_name.find("Switch") != std::string::npos) {
248       return true;
249     }
250   }
251   return false;
252 }
253 
/**
 * Obfuscate `func_graph` in place by inserting fake subgraph branches.
 *
 * Every node's abstract is broadened first; then SubGraphFakeBranch does the actual rewriting.
 *
 * @param func_graph Graph to obfuscate; must not be null.
 * @return The same graph object, modified in place.
 * @throws ValueError if the graph already looks obfuscated (see CheckIfObfuscated).
 */
FuncGraphPtr DynamicObfuscator::ObfuscateMindIR(const FuncGraphPtr &func_graph) {
  MS_LOG(INFO) << "Start obfuscation.";
  MS_EXCEPTION_IF_NULL(func_graph);
  if (CheckIfObfuscated(func_graph)) {
    MS_EXCEPTION(ValueError) << "The input model has been obfuscated, do not obfuscate it again.";
  }
  auto mgr = Manage(func_graph);
  MS_EXCEPTION_IF_NULL(mgr);
  const auto &all_nodes = mgr->all_nodes();
  for (const auto &item : all_nodes) {
    if (item == nullptr) {
      continue;
    }
    auto abs = item->abstract();
    if (abs != nullptr) {
      item->set_abstract(abs->Broaden());
    }
  }
  // Count before SubGraphFakeBranch mutates the graph; size_t avoids the previous long -> int narrowing.
  const size_t node_nums = all_nodes.size();
  MS_LOG(INFO) << "Total node num: " << node_nums;

  // do subgraph fake-branch obfuscation
  SubGraphFakeBranch(func_graph);

  if (subgraph_obf_num_ == 0) {
    MS_LOG(WARNING)
      << "The model has not been obfuscated, which means obf_random_seed or customized_func is not need to set.";
  }
  return func_graph;
}
281 
ObfuscateOpType(const AnfNodePtr & node)282 std::string DynamicObfuscator::ObfuscateOpType(const AnfNodePtr &node) {
283   if (node == nullptr) {
284     MS_LOG(ERROR) << "Input node is nullptr, get name failed!";
285     return "";
286   }
287   if (node->isa<CNode>()) {
288     MS_LOG(INFO) << "The node_name is: " << node->fullname_with_scope();
289     std::string op_name = get_node_prim_name(node);
290     std::vector<std::string> target_op_list;
291     target_op_list.insert(target_op_list.end(), single_input_target_op_.begin(), single_input_target_op_.end());
292     target_op_list.insert(target_op_list.end(), single_input_with_weight_target_op_.begin(),
293                           single_input_with_weight_target_op_.end());
294 
295     auto found = std::find_if(target_op_list.cbegin(), target_op_list.cend(),
296                               [&](const auto &target_name) { return op_name == target_name; });
297     if (found != target_op_list.cend()) {
298       return *found;
299     }
300   }
301   return "";
302 }
303 
ObfuscateOpCase(const std::string obf_type)304 ObfCase DynamicObfuscator::ObfuscateOpCase(const std::string obf_type) {
305   if (obf_type.empty()) {
306     MS_LOG(ERROR) << "Obf_type is empty string.";
307     return ObfCase::NotObfNode;
308   }
309   auto name_equal = [&obf_type](const std::string &s) { return s == obf_type; };
310   if (std::any_of(single_input_target_op_.begin(), single_input_target_op_.end(), name_equal)) {
311     return ObfCase::OneInputNoWeightNode;
312   } else if (std::any_of(single_input_with_weight_target_op_.begin(), single_input_with_weight_target_op_.end(),
313                          name_equal)) {
314     return ObfCase::OneInputWithWeightNode;
315   } else {
316     return ObfCase::NotObfNode;
317   }
318 }
319 
RandomSeedModeControl(const FuncGraphPtr func_graph)320 CNodePtr DynamicObfuscator::RandomSeedModeControl(const FuncGraphPtr func_graph) {
321   ShapeVector y_shape{1};
322   tensor::TensorPtr y_tensor = std::make_shared<Tensor>(mindspore::kNumberTypeInt32, y_shape);
323   if (!has_build_appended_input) {
324     MS_LOG(INFO) << "Build parameter y_append.";
325     auto y_append = AddObfuscatedParam(func_graph);
326     y_append->set_name("y_append");
327     y_append->set_abstract(y_tensor->ToAbstract());
328     has_build_appended_input = true;
329   }
330   auto y_append = func_graph->GetParameterByName("y_append");
331 
332   if (used_control_node_ == 0) {
333     // make equal function node
334     ValueNodePtr equal_v_node = std::make_shared<mindspore::ValueNode>(mindspore::prim::kPrimEqual);
335     func_graph->AddValueNode(equal_v_node);
336     ValueNodePtr equal_compa_node = make_int_node(func_graph, branch_control_input_);
337     CNodePtr equal_c_node = func_graph->NewCNode({equal_v_node, y_append, equal_compa_node});
338     if (equal_c_node == nullptr) {
339       MS_LOG(ERROR) << "equal_c_node is nullptr.";
340       return nullptr;
341     }
342     tensor::TensorPtr equal_tensor = std::make_shared<Tensor>(mindspore::kNumberTypeBool, y_shape);
343     equal_c_node->set_abstract(equal_tensor->ToAbstract());
344     func_graph->AddNode(equal_c_node);
345     used_control_node_ += 1;
346     switch_branch_ = true;
347     return equal_c_node;
348   }
349   // make greater function node
350   int comparison_int = rand();
351   ValueNodePtr greater_v_node = std::make_shared<mindspore::ValueNode>(mindspore::prim::kPrimGreater);
352   func_graph->AddValueNode(greater_v_node);
353   ValueNodePtr greater_compa_node = make_int_node(func_graph, comparison_int);
354   CNodePtr greater_c_node = func_graph->NewCNode({greater_v_node, y_append, greater_compa_node});
355   if (greater_c_node == nullptr) {
356     MS_LOG(ERROR) << "greater_c_node is nullptr.";
357     return nullptr;
358   }
359   tensor::TensorPtr greater_tensor = std::make_shared<Tensor>(mindspore::kNumberTypeBool, y_shape);
360   greater_c_node->set_abstract(greater_tensor->ToAbstract());
361   func_graph->AddNode(greater_c_node);
362   used_control_node_ += 1;
363   switch_branch_ = branch_control_input_ > comparison_int;
364   return greater_c_node;
365 }
366 
CreateScalarValue(const FuncGraphPtr & func_graph,int64_t value)367 ValueNodePtr CreateScalarValue(const FuncGraphPtr &func_graph, int64_t value) {
368   auto scalar_value = MakeValue(value);
369   auto scalar_node = NewValueNode(scalar_value);
370   scalar_node->set_abstract(scalar_value->ToAbstract());
371   func_graph->AddValueNode(scalar_node);
372   return scalar_node;
373 }
374 
add_stride_slice_node(FuncGraphPtr func_graph,ShapeVector begin_vector,ShapeVector stride_vector,ShapeVector end_vector,int end_mask,int begin_mask,mindspore::CNodePtr prev_node)375 mindspore::CNodePtr add_stride_slice_node(FuncGraphPtr func_graph, ShapeVector begin_vector, ShapeVector stride_vector,
376                                           ShapeVector end_vector, int end_mask, int begin_mask,
377                                           mindspore::CNodePtr prev_node) {
378   mindspore::ValueNodePtr begin_v_node = build_tuple_value_node(begin_vector);
379   mindspore::ValueNodePtr stride_v_node = build_tuple_value_node(stride_vector);
380   mindspore::ValueNodePtr end_v_node = build_tuple_value_node(end_vector);
381   auto begin_mask_node = CreateScalarValue(func_graph, begin_mask);
382   MS_EXCEPTION_IF_NULL(begin_mask_node);
383   auto end_mask_node = CreateScalarValue(func_graph, end_mask);
384   MS_EXCEPTION_IF_NULL(end_mask_node);
385   auto ellipsis_mask_node = CreateScalarValue(func_graph, int64_t(0));
386   MS_EXCEPTION_IF_NULL(ellipsis_mask_node);
387   auto new_axis_mask_node = CreateScalarValue(func_graph, int64_t(0));
388   MS_EXCEPTION_IF_NULL(new_axis_mask_node);
389   auto shrink_axis_mask_node = CreateScalarValue(func_graph, int64_t(1));
390   MS_EXCEPTION_IF_NULL(shrink_axis_mask_node);
391   func_graph->AddValueNode(begin_v_node);
392   func_graph->AddValueNode(stride_v_node);
393   func_graph->AddValueNode(end_v_node);
394   mindspore::PrimitivePtr slice_prim = mindspore::prim::kPrimStridedSlice;
395   slice_prim->set_attr("is_load", MakeValue(true));
396   mindspore::ValueNodePtr slice_v_node = std::make_shared<mindspore::ValueNode>(slice_prim);
397   func_graph->AddValueNode(slice_v_node);
398   mindspore::CNodePtr slice_c_node =
399     func_graph->NewCNode({slice_v_node, prev_node, begin_v_node, end_v_node, stride_v_node, begin_mask_node,
400                           end_mask_node, ellipsis_mask_node, new_axis_mask_node, shrink_axis_mask_node});
401   return slice_c_node;
402 }
403 
CustomOpModeControl(const FuncGraphPtr func_graph,const AnfNodePtr & prev_node) const404 CNodePtr DynamicObfuscator::CustomOpModeControl(const FuncGraphPtr func_graph, const AnfNodePtr &prev_node) const {
405   mindspore::PrimitivePtr reshape_prim = mindspore::prim::kPrimReshape;
406   reshape_prim->set_attr("is_load", MakeValue(true));
407   mindspore::ValueNodePtr reshape_v_node = std::make_shared<mindspore::ValueNode>(reshape_prim);
408   func_graph->AddValueNode(reshape_v_node);
409   ShapeVector prev_node_shape = get_node_shape(prev_node);
410   int shape_multiply = std::accumulate(prev_node_shape.cbegin(), prev_node_shape.cend(), 1, std::multiplies<int>());
411   MS_LOG(INFO) << "The shape_multiply is: " << shape_multiply;
412 
413   ShapeVector flat_shape{1, shape_multiply};
414   mindspore::ValueNodePtr shape_v_node = std::make_shared<mindspore::ValueNode>(MakeValue(flat_shape));
415   func_graph->AddValueNode(shape_v_node);
416   mindspore::CNodePtr reshape_c_node = func_graph->NewCNode({reshape_v_node, prev_node, shape_v_node});
417   TypeId data_type = get_node_dtype(prev_node);
418   auto reshape_abstract = std::make_shared<Tensor>(data_type, flat_shape)->ToAbstract();
419   reshape_c_node->set_abstract(reshape_abstract);
420   func_graph->AddNode(reshape_c_node);
421 
422   // the first stride_slice x[0]
423   ShapeVector begin_1{0, 0};
424   ShapeVector stride_1{1, 1};
425   mindspore::CNodePtr slice_c_node_1 =
426     add_stride_slice_node(func_graph, begin_1, stride_1, flat_shape, 2, 2, reshape_c_node);
427   ShapeVector slice_1_shape{shape_multiply};
428   slice_c_node_1->set_abstract(std::make_shared<Tensor>(data_type, slice_1_shape)->ToAbstract());
429   func_graph->AddNode(slice_c_node_1);
430 
431   // the first stride_slice x[0][0]
432   ShapeVector begin_2{0};
433   ShapeVector end_2{1};
434   ShapeVector stride_2{1};
435   mindspore::CNodePtr slice_c_node_2 =
436     add_stride_slice_node(func_graph, begin_2, stride_2, stride_2, 0, 0, slice_c_node_1);
437   ShapeVector slice_2_shape{1};
438   slice_c_node_2->set_abstract(std::make_shared<Tensor>(data_type, slice_2_shape)->ToAbstract());
439   func_graph->AddNode(slice_c_node_2);
440 
441   // the second stride_slice x[0][1]
442   ShapeVector begin_3{1};
443   ShapeVector end_3{1};
444   ShapeVector stride_3{2};
445   mindspore::CNodePtr slice_c_node_3 =
446     add_stride_slice_node(func_graph, begin_3, stride_3, stride_3, 0, 0, slice_c_node_1);
447   ShapeVector slice_3_shape{1};
448   slice_c_node_3->set_abstract(std::make_shared<Tensor>(data_type, slice_3_shape)->ToAbstract());
449   func_graph->AddNode(slice_c_node_3);
450 
451   // add opaque predicate
452   PrimitivePtr custom_prim = mindspore::prim::kPrimOpaquePredicate;
453   custom_prim->set_attr("is_load", MakeValue(true));
454   std::vector<ValuePtr> input_names_value;
455   input_names_value.push_back(std::make_shared<StringImm>("x"));
456   input_names_value.push_back(std::make_shared<StringImm>("y"));
457   custom_prim->set_attr(mindspore::kAttrInputNames, std::make_shared<ValueList>(input_names_value));
458   std::vector<ValuePtr> output_names_value;
459   output_names_value.push_back(std::make_shared<StringImm>("output"));
460   custom_prim->set_attr(mindspore::kAttrOutputNames, std::make_shared<ValueList>(output_names_value));
461   auto opaque_v_node = std::make_shared<mindspore::ValueNode>(custom_prim);
462   func_graph->AddValueNode(opaque_v_node);
463   auto opaque_c_node = func_graph->NewCNode({opaque_v_node, slice_c_node_2, slice_c_node_3});
464   ShapeVector y_shape{1};
465   auto bool_tensor = std::make_shared<Tensor>(mindspore::kNumberTypeBool, y_shape);
466   opaque_c_node->set_abstract(bool_tensor->ToAbstract());
467   func_graph->AddNode(opaque_c_node);
468   return opaque_c_node;
469 }
470 
GetControlNode(const FuncGraphPtr & func_graph,const AnfNodePtr & prev_node)471 CNodePtr DynamicObfuscator::GetControlNode(const FuncGraphPtr &func_graph, const AnfNodePtr &prev_node) {
472   MS_EXCEPTION_IF_NULL(func_graph);
473   MS_EXCEPTION_IF_NULL(prev_node);
474   if (branch_control_input_ != 0) {
475     MS_LOG(INFO) << "Run password mode.";
476     return RandomSeedModeControl(func_graph);
477   }
478   MS_LOG(INFO) << "Run customized function mode.";
479   if (prev_node != nullptr && prev_node->abstract() != nullptr) {
480     return CustomOpModeControl(func_graph, prev_node);
481   }
482   return nullptr;
483 }
484 
/**
 * Pick a replacement primitive for the op named by `obf_type` (text before the first "-").
 *
 * - MaxPool and AvgPool are swapped with each other, and the pooling attributes of `node` are copied over.
 * - Any other op gets a randomly chosen entry of one_input_prim_ with a different name.
 *
 * @param obf_type Obfuscation type string; its first "-"-separated field is the original op name.
 * @param node Original CNode (only read for the pooling case).
 * @return The replacement primitive, or nullptr if obf_type is empty or no different candidate primitive exists.
 *         The old code indexed an empty list and looped forever when no alternative existed.
 * @throws If a pooling node lacks its primitive or one of input_names/output_names/format/kernel_size/strides.
 */
mindspore::PrimitivePtr DynamicObfuscator::get_random_prim(const std::string &obf_type,
                                                           const mindspore::CNodePtr &node) {
  std::vector<std::string> split_words = name_split(obf_type, "-");
  if (split_words.empty()) {
    MS_LOG(WARNING) << "obf_type is empty.";
    return nullptr;
  }
  std::string prim_name_ori = split_words[0];
  mindspore::PrimitivePtr poolptr = nullptr;
  if (prim_name_ori == kMaxPoolOpName || prim_name_ori == kAvgPoolOpName) {
    if (prim_name_ori == kMaxPoolOpName) {
      poolptr = std::make_shared<Primitive>("AvgPool");
    } else {
      poolptr = std::make_shared<Primitive>("MaxPool");
    }
    auto primitive = GetCNodePrimitive(node);
    MS_EXCEPTION_IF_NULL(primitive);
    MS_EXCEPTION_IF_NULL(primitive->GetAttr("input_names"));
    MS_EXCEPTION_IF_NULL(primitive->GetAttr("output_names"));
    MS_EXCEPTION_IF_NULL(primitive->GetAttr("format"));
    MS_EXCEPTION_IF_NULL(primitive->GetAttr("kernel_size"));
    MS_EXCEPTION_IF_NULL(primitive->GetAttr("strides"));
    poolptr->set_attr("input_names", primitive->GetAttr("input_names"));
    poolptr->set_attr("output_names", primitive->GetAttr("output_names"));
    poolptr->set_attr("format", primitive->GetAttr("format"));
    // pad_mode is not validated above; copy it only when present instead of storing a null attribute value.
    auto pad_mode = primitive->GetAttr("pad_mode");
    if (pad_mode != nullptr) {
      poolptr->set_attr("pad_mode", pad_mode);
    } else {
      MS_LOG(INFO) << "Primitive " << prim_name_ori << " has no pad_mode attribute, skip copying it.";
    }
    poolptr->set_attr("kernel_size", primitive->GetAttr("kernel_size"));
    poolptr->set_attr("strides", primitive->GetAttr("strides"));
    return poolptr;
  }
  // A usable candidate is non-null and differs from the original op; at least one must exist for the
  // rejection-sampling loop below to terminate.
  const auto is_alternative = [&prim_name_ori](const auto &prim) {
    return prim != nullptr && prim->ToString() != prim_name_ori;
  };
  if (!std::any_of(one_input_prim_.cbegin(), one_input_prim_.cend(), is_alternative)) {
    MS_LOG(WARNING) << "No primitive different from " << prim_name_ori << " is available for obfuscation.";
    return nullptr;
  }
  mindspore::PrimitivePtr prim_node = nullptr;
  do {
    int random = rand() % SizeToInt(one_input_prim_.size());
    prim_node = one_input_prim_[random];
  } while (!is_alternative(prim_node));
  return prim_node;
}
522 
UpdateDict(const AnfNodePtr & node,const bool isParent)523 void DynamicObfuscator::UpdateDict(const AnfNodePtr &node, const bool isParent) {
524   if (node == nullptr) {
525     MS_LOG(ERROR) << "Input node is nullptr, update dict failed.";
526     return;
527   }
528   MS_LOG(INFO) << "Update: " << node->fullname_with_scope() << " to dict.";
529   if (isParent) {
530     parent_names_.push(node->fullname_with_scope());
531   } else {
532     node_names_.push(node->fullname_with_scope());
533     subgraph_obf_num_++;
534   }
535   node_dict_[node->fullname_with_scope()] = node->cast<mindspore::AnfNodePtr>();
536   if (node_dict_[node->fullname_with_scope()] == nullptr) {
537     MS_LOG(ERROR) << "Update node " << node->fullname_with_scope() << " failed.";
538   }
539 }
540 
CheckDuplicatedParent(const AnfNodePtr & node)541 void DynamicObfuscator::CheckDuplicatedParent(const AnfNodePtr &node) {
542   if (node == nullptr) {
543     MS_LOG(ERROR) << "Input node is nullptr, check parent failed.";
544     return;
545   }
546   if (node_dict_.find(node->fullname_with_scope()) != node_dict_.cend()) {
547     while (node_names_.top() != "-") {
548       node_dict_.erase(node_names_.top());
549       node_names_.pop();
550       subgraph_obf_num_--;
551     }
552   } else {
553     node_names_.push("-");
554     UpdateDict(node, true);
555     if (branch_control_input_ == 0) {
556       bool customized_func_result = mindspore::kernel::CustomizedOpaquePredicate::GetInstance().run_function(
557         static_cast<float>(1), static_cast<float>(1));
558       customized_func_results_.push_back(customized_func_result);
559     }
560   }
561 }
562 
IsTarget(const std::string & cnode_name)563 bool DynamicObfuscator::IsTarget(const std::string &cnode_name) {
564   if (cnode_name.empty()) {
565     MS_LOG(INFO) << "CNode name is empty.";
566     return false;
567   }
568   std::vector<std::string> target_op_list;
569   target_op_list.insert(target_op_list.end(), single_input_target_op_.begin(), single_input_target_op_.end());
570   target_op_list.insert(target_op_list.end(), single_input_with_weight_target_op_.begin(),
571                         single_input_with_weight_target_op_.end());
572   if (std::find(target_op_list.cbegin(), target_op_list.cend(), cnode_name) != target_op_list.cend()) {
573     return true;
574   }
575   return false;
576 }
577 
/**
 * Return the first input of `node` whose primitive is an obfuscation target (see IsTarget).
 *
 * @param node CNode whose inputs are scanned in order.
 * @return That input cast to CNode, or nullptr if `node` is null (error logged) or no input matches.
 */
mindspore::CNodePtr DynamicObfuscator::CheckInputNodes(const mindspore::CNodePtr &node) {
  if (node == nullptr) {
    MS_LOG(ERROR) << "Input node is nullptr, check input failed.";
    return nullptr;
  }
  const auto inputs = node->inputs();
  const auto hit = std::find_if(inputs.cbegin(), inputs.cend(),
                                [this](const auto &candidate) { return IsTarget(get_node_prim_name(candidate)); });
  if (hit == inputs.cend()) {
    return nullptr;
  }
  return (*hit)->cast<mindspore::CNodePtr>();
}
592 
/**
 * Build and register a one-input CNode `prim_node(input_node)` in `fg`.
 *
 * `prim_node` is tagged with is_load = true and input_names = ["x"]; note that this mutates the passed primitive.
 * The new node's abstract is a tensor with `input_node`'s shape and dtype.
 *
 * @param fg Graph receiving the node.
 * @param input_node Sole data input.
 * @param prim_node Primitive to apply.
 * @return The new CNode, or nullptr if any argument is null or construction fails (error logged).
 */
mindspore::CNodePtr DynamicObfuscator::BuildOneInputNoWeightNode(const FuncGraphPtr &fg,
                                                                 const mindspore::AnfNodePtr &input_node,
                                                                 const mindspore::PrimitivePtr prim_node) const {
  // Checks run in this order so the first failing argument is the one reported.
  if (input_node == nullptr) {
    MS_LOG(ERROR) << "Build Node failed: input node is nullptr.";
    return nullptr;
  }
  if (fg == nullptr) {
    MS_LOG(ERROR) << "Build Node failed: FuncGraph is nullptr.";
    return nullptr;
  }
  if (prim_node == nullptr) {
    MS_LOG(ERROR) << "Build Node failed: prim_node is nullptr.";
    return nullptr;
  }
  prim_node->set_attr("is_load", MakeValue(true));
  prim_node->set_attr(mindspore::kAttrInputNames,
                      std::make_shared<ValueList>(std::vector<ValuePtr>{std::make_shared<StringImm>("x")}));
  auto prim_value_node = std::make_shared<mindspore::ValueNode>(prim_node);
  fg->AddValueNode(prim_value_node);
  mindspore::CNodePtr built = fg->NewCNode({prim_value_node, input_node});
  if (built == nullptr) {
    MS_LOG(ERROR) << "Build node failed: cnode is nullptr.";
    return nullptr;
  }
  const ShapeVector out_shape = get_node_shape(input_node);
  const TypeId out_type = get_node_dtype(input_node);
  auto out_abstract = std::make_shared<Tensor>(out_type, out_shape)->ToAbstract();
  if (out_abstract == nullptr) {
    MS_LOG(ERROR) << "Build node failed: node abstract is nullptr.";
    return nullptr;
  }
  built->set_abstract(out_abstract);
  fg->AddNode(built);
  return built;
}
630 
/**
 * Build and register a CNode `op(input_node, weights)` in `fg`, reusing the primitive ValueNode of `node`.
 *
 * The abstract (shape and dtype) is copied from the original `node`, since the replacement has the same output.
 *
 * @param fg Graph receiving the node.
 * @param input_node Data input.
 * @param node Original CNode; its first input must be the primitive ValueNode.
 * @param weights Weight input.
 * @return The new CNode, or nullptr on any null argument, empty inputs, a missing/non-ValueNode first input, or a
 *         construction failure (error logged).
 */
mindspore::CNodePtr DynamicObfuscator::BuildOneInputWithWeightNode(const FuncGraphPtr &fg,
                                                                   const mindspore::AnfNodePtr &input_node,
                                                                   const mindspore::CNodePtr &node,
                                                                   const mindspore::AnfNodePtr &weights) const {
  if (node == nullptr) {
    MS_LOG(ERROR) << "Build one input with weight node failed: node is nullptr.";
    return nullptr;
  }
  std::string node_name = node->fullname_with_scope();
  if (input_node == nullptr) {
    MS_LOG(ERROR) << "Build " << node_name << " failed: input node is nullptr.";
    return nullptr;
  }
  if (fg == nullptr) {
    MS_LOG(ERROR) << "Build " << node_name << " failed: FuncGraph is nullptr.";
    return nullptr;
  }
  if (weights == nullptr) {
    MS_LOG(ERROR) << "Build " << node_name << " failed: weights is nullptr.";
    return nullptr;
  }
  const std::vector<AnfNodePtr> &node_inputs = node->inputs();
  if (node_inputs.empty()) {
    MS_LOG(ERROR) << "Build " << node_name << " failed: inputs size is 0";
    return nullptr;
  }
  // The first input must be the primitive ValueNode; otherwise a null node would be registered and wired in.
  if (node_inputs[0] == nullptr) {
    MS_LOG(ERROR) << "Build " << node_name << " failed: the first input is nullptr.";
    return nullptr;
  }
  mindspore::ValueNodePtr v_node = node_inputs[0]->cast<mindspore::ValueNodePtr>();
  if (v_node == nullptr) {
    MS_LOG(ERROR) << "Build " << node_name << " failed: the first input is not a ValueNode.";
    return nullptr;
  }
  fg->AddValueNode(v_node);

  mindspore::CNodePtr c_node = fg->NewCNode({v_node, input_node, weights});
  if (c_node == nullptr) {
    MS_LOG(ERROR) << "Build " << node_name << " failed: cnode is nullptr.";
    return nullptr;
  }
  ShapeVector x_shape = get_node_shape(node);
  TypeId type_id = get_node_dtype(node);
  auto node_abstract = std::make_shared<Tensor>(type_id, x_shape)->ToAbstract();
  if (node_abstract == nullptr) {
    MS_LOG(ERROR) << "Build " << node_name << " failed: abstract is nullptr.";
    return nullptr;
  }
  c_node->set_abstract(node_abstract);
  (void)fg->AddNode(c_node);
  return c_node;
}
676 
CloneSubGraph(const std::vector<mindspore::CNodePtr> & node_arr,const mindspore::AnfNodePtr & parent_node)677 FuncGraphPtr DynamicObfuscator::CloneSubGraph(const std::vector<mindspore::CNodePtr> &node_arr,
678                                               const mindspore::AnfNodePtr &parent_node) {
679   MS_LOG(INFO) << "Building Clone Graph ";
680   mindspore::FuncGraphPtr fg_clone = std::make_shared<FuncGraph>();
681   ShapeVector x_shape = get_node_shape(parent_node);
682   TypeId x_type_id = get_node_dtype(parent_node);
683   MS_LOG(INFO) << "Get Shape Input X";
684 
685   mindspore::ParameterPtr input_x = fg_clone->add_parameter();
686   if (input_x == nullptr) {
687     MS_LOG(ERROR) << "Build clone graph failed: input_x is nullptr.";
688     return nullptr;
689   }
690   input_x->set_name("input_x_clone");
691   tensor::TensorPtr input_x_tensor = std::make_shared<Tensor>(x_type_id, x_shape);
692   input_x->set_abstract(input_x_tensor->ToAbstract());
693   mindspore::AnfNodePtr last_node = input_x;
694   for (auto node : node_arr) {
695     std::string obf_type = ObfuscateOpType(node);
696     MS_LOG(INFO) << "obf_type: " << obf_type;
697     mindspore::ObfCase obf_case = ObfuscateOpCase(obf_type);
698     switch (obf_case) {
699       case ObfCase::OneInputNoWeightNode: {
700         mindspore::PrimitivePtr prim_node = GetCNodePrimitive(node);
701         last_node = BuildOneInputNoWeightNode(fg_clone, last_node, prim_node);
702         if (last_node == nullptr) {
703           MS_LOG(ERROR) << "Last node after build is nullptr.";
704           return nullptr;
705         }
706         break;
707       }
708       case ObfCase::OneInputWithWeightNode: {
709         mindspore::ParameterPtr weight_param = fg_clone->add_parameter();
710         if (weight_param == nullptr) {
711           MS_LOG(ERROR) << "Build OneInputWithWeightNode failed: weights is nullptr.";
712           return nullptr;
713         }
714         weight_param->set_name("OneInputWithWeightNode_clone");
715         last_node = BuildOneInputWithWeightNode(fg_clone, last_node, node, weight_param);
716         if (last_node == nullptr) {
717           MS_LOG(ERROR) << "Last node after build is nullptr.";
718           return nullptr;
719         }
720         break;
721       }
722       case ObfCase::NotObfNode: {
723         MS_LOG(ERROR) << "The current node does not belong to target nodes.";
724       }
725       default:
726         return nullptr;
727     }
728   }
729 
730   mindspore::ValueNodePtr return_v = std::make_shared<mindspore::ValueNode>(mindspore::prim::kPrimReturn);
731   fg_clone->AddValueNode(return_v);
732   mindspore::CNodePtr return_c_node = fg_clone->NewCNode({return_v, last_node});
733   if (return_c_node == nullptr) {
734     MS_LOG(ERROR) << "Build return failed: return cnode is nullptr.";
735     return nullptr;
736   }
737   ShapeVector return_shape = get_node_shape(last_node->cast<mindspore::CNodePtr>());
738   TypeId type_id = get_node_dtype(last_node->cast<mindspore::CNodePtr>());
739   auto return_abstract = std::make_shared<Tensor>(type_id, return_shape)->ToAbstract();
740   if (return_abstract == nullptr) {
741     MS_LOG(ERROR) << "Build return failed: return abstract is nullptr.";
742     return nullptr;
743   }
744   return_c_node->set_abstract(return_abstract);
745   fg_clone->AddNode(return_c_node);
746   fg_clone->set_flag(FUNC_GRAPH_FLAG_DEFER_INLINE, true);
747   fg_clone->set_return(return_c_node);
748   return fg_clone;
749 }
750 
BuildFakeGraph(const std::vector<mindspore::CNodePtr> & node_arr,const mindspore::AnfNodePtr & parent_node)751 FuncGraphPtr DynamicObfuscator::BuildFakeGraph(const std::vector<mindspore::CNodePtr> &node_arr,
752                                                const mindspore::AnfNodePtr &parent_node) {
753   MS_LOG(INFO) << "Building Fake Graph ";
754   mindspore::FuncGraphPtr fg_fake = std::make_shared<FuncGraph>();
755 
756   ShapeVector x_shape = get_node_shape(parent_node);
757   TypeId x_type_id = get_node_dtype(parent_node);
758   mindspore::ParameterPtr input_x = fg_fake->add_parameter();
759   if (input_x == nullptr) {
760     MS_LOG(ERROR) << "Build fake graph failed: input_x is nullptr.";
761     return nullptr;
762   }
763   input_x->set_name("input_x_fake");
764   tensor::TensorPtr input_x_tensor = std::make_shared<Tensor>(x_type_id, x_shape);
765   input_x->set_abstract(input_x_tensor->ToAbstract());
766   mindspore::AnfNodePtr last_node = input_x;
767   for (auto node : node_arr) {
768     std::string obf_type = ObfuscateOpType(node);
769     mindspore::ObfCase obf_case = ObfuscateOpCase(obf_type);
770     switch (obf_case) {
771       case ObfCase::OneInputNoWeightNode: {
772         mindspore::PrimitivePtr prim_node = get_random_prim(obf_type, node);
773         last_node = BuildOneInputNoWeightNode(fg_fake, last_node, prim_node);
774         if (last_node == nullptr) {
775           MS_LOG(ERROR) << "Last node after build is nullptr.";
776           return nullptr;
777         }
778         break;
779       }
780       case ObfCase::OneInputWithWeightNode: {
781         mindspore::AnfNodePtr ori_vnode = node->cast<mindspore::CNodePtr>()->inputs()[2];
782         TypeId type_id = get_node_dtype(ori_vnode);
783         ShapeVector shape = get_node_shape(ori_vnode);
784         tensor::TensorPtr weight_tensor = make_weight_tensor(type_id, shape);
785         mindspore::ValueNodePtr weight_vnode = std::make_shared<mindspore::ValueNode>(weight_tensor);
786         if (weight_vnode == nullptr) {
787           MS_LOG(ERROR) << "Build OneInputWithWeightNode failed: value node is nullptr.";
788           return nullptr;
789         }
790         weight_vnode->set_abstract(weight_tensor->ToAbstract());
791         fg_fake->AddValueNode(weight_vnode);
792         last_node = BuildOneInputWithWeightNode(fg_fake, last_node, node, weight_vnode);
793         if (last_node == nullptr) {
794           MS_LOG(ERROR) << "Last node after build is nullptr.";
795           return nullptr;
796         }
797         break;
798       }
799       case ObfCase::NotObfNode: {
800         MS_LOG(ERROR) << "The current node is not obf-target";
801       }
802       default:
803         return nullptr;
804     }
805   }
806 
807   mindspore::ValueNodePtr return_v = std::make_shared<mindspore::ValueNode>(mindspore::prim::kPrimReturn);
808   fg_fake->AddValueNode(return_v);
809   mindspore::CNodePtr return_c_node = fg_fake->NewCNode({return_v, last_node});
810   if (return_c_node == nullptr) {
811     MS_LOG(ERROR) << "Build return failed: return cnode is nullptr.";
812     return nullptr;
813   }
814   ShapeVector return_shape = get_node_shape(last_node->cast<mindspore::CNodePtr>());
815   TypeId type_id = get_node_dtype(last_node->cast<mindspore::CNodePtr>());
816   auto return_abstract = std::make_shared<Tensor>(type_id, return_shape)->ToAbstract();
817   if (return_abstract == nullptr) {
818     MS_LOG(ERROR) << "Build return failed: return abstract is nullptr.";
819     return nullptr;
820   }
821   return_c_node->set_abstract(return_abstract);
822   fg_fake->AddNode(return_c_node);
823   fg_fake->set_return(return_c_node);
824   fg_fake->set_flag(FUNC_GRAPH_FLAG_DEFER_INLINE, true);
825   return fg_fake;
826 }
827 
AddPartialBranch(const FuncGraphPtr fg,FuncGraphPtr fg_sub,const std::vector<mindspore::CNodePtr> & nodes)828 mindspore::CNodePtr DynamicObfuscator::AddPartialBranch(const FuncGraphPtr fg, FuncGraphPtr fg_sub,
829                                                         const std::vector<mindspore::CNodePtr> &nodes) {
830   if (fg == nullptr) {
831     MS_LOG(ERROR) << "Add subgraph failed: fg is null.";
832     return nullptr;
833   }
834   if (fg_sub == nullptr) {
835     MS_LOG(ERROR) << "Add subgraph failed: fg_sub is null.";
836     return nullptr;
837   }
838   if (nodes.size() == 0) {
839     MS_LOG(ERROR) << "Add subgraph failed: input nodes size is 0.";
840     return nullptr;
841   }
842 
843   mindspore::ValueNodePtr switch_partial = std::make_shared<mindspore::ValueNode>(mindspore::prim::kPrimPartial);
844   fg->AddValueNode(switch_partial);
845   mindspore::ValueNodePtr fg_subgraph_node = std::make_shared<mindspore::ValueNode>(fg_sub);
846   fg_subgraph_node->set_abstract(fg_sub->ToAbstract());
847   fg->AddValueNode(fg_subgraph_node);
848   std::vector<mindspore::AnfNodePtr> subgraph_inputs = {switch_partial, fg_subgraph_node};
849   if (nodes[0]->size() < kSwitchInputsNum) {
850     MS_LOG(ERROR) << "Add subgraph failed: the input number of node[0] is smaller than " << kSwitchInputsNum;
851     return nullptr;
852   }
853   subgraph_inputs.push_back(nodes[0]->inputs()[1]);
854   size_t func_params_num = fg_sub->parameters().size();
855   size_t pushed_inputs = 1;
856   for (unsigned i = 0; i < nodes.size(); i++) {
857     if (pushed_inputs >= func_params_num) {
858       break;
859     }
860     std::string obf_type = ObfuscateOpType(nodes[i]);
861     if ((obf_type == kConv2DOpName || obf_type == kMatMulOpName) && nodes[i]->size() >= kNodeWithWeightInputsNum) {
862       subgraph_inputs.push_back(nodes[i]->inputs()[kWeightIndex]);
863       pushed_inputs += 1;
864     }
865   }
866   mindspore::CNodePtr switch_partial_c = fg->NewCNode(subgraph_inputs);
867   if (switch_partial_c == nullptr) {
868     MS_LOG(ERROR) << "Add subgraph failed: switch partial is null.";
869     return nullptr;
870   }
871   switch_partial_c->set_abstract(fg_sub->ToAbstract());
872   fg->AddNode(switch_partial_c);
873   return switch_partial_c;
874 }
875 
AddSwitchNode(const FuncGraphPtr fg)876 void DynamicObfuscator::AddSwitchNode(const FuncGraphPtr fg) {
877   if (fg == nullptr) {
878     MS_LOG(ERROR) << "Build switch failed: FuncGraph is nullptr.";
879     return;
880   }
881   int switch_num_ = 0;
882   while (!parent_names_.empty()) {
883     auto mgr = mindspore::Manage(fg);
884     if (mgr == nullptr) {
885       MS_LOG(ERROR) << "FuncGraph manager is nullptr.";
886       return;
887     }
888     std::vector<mindspore::CNodePtr> nodes;
889     mindspore::AnfNodePtr last_node = nullptr;
890     mindspore::CNodePtr child_node = nullptr;
891     while (node_names_.top() != "-") {
892       MS_LOG(INFO) << "Processing sub_graph node: " << node_names_.top();
893       last_node = node_dict_[node_names_.top()];
894       nodes.push_back(last_node->cast<mindspore::CNodePtr>());
895       node_names_.pop();  // pop '-'
896     }
897     node_names_.pop();
898     if (mgr->node_users().find(last_node) != mgr->node_users().cend()) {
899       auto users = mgr->node_users()[last_node];
900       child_node = users.cbegin()->first->cast<mindspore::CNodePtr>();
901     } else {
902       MS_LOG(WARNING) << "Child Node of " << last_node->fullname_with_scope() << " is nullptr.";
903     }
904     mindspore::AnfNodePtr parent_node = node_dict_[parent_names_.top()];
905     parent_names_.pop();
906 
907     mindspore::FuncGraphPtr fg_subgraph_clone = CloneSubGraph(nodes, parent_node);
908     mindspore::FuncGraphPtr fg_subgraph_fake = BuildFakeGraph(nodes, parent_node);
909 
910     mgr->AddFuncGraph(fg_subgraph_clone);
911     mgr->AddFuncGraph(fg_subgraph_fake);
912 
913     mindspore::CNodePtr switch_partial_clone_c = AddPartialBranch(fg, fg_subgraph_clone, nodes);
914     mindspore::CNodePtr switch_partial_fake_c = AddPartialBranch(fg, fg_subgraph_fake, nodes);
915     if (switch_partial_clone_c == nullptr || switch_partial_fake_c == nullptr) {
916       continue;
917     }
918 
919     CNodePtr control_node = GetControlNode(fg, parent_node);
920     if (control_node == nullptr) {
921       continue;
922     }
923 
924     mindspore::ValueNodePtr switch_v_node = std::make_shared<mindspore::ValueNode>(mindspore::prim::kPrimSwitch);
925     fg->AddValueNode(switch_v_node);
926     mindspore::CNodePtr switch_c_node;
927     if (branch_control_input_ == 0) {
928       if (static_cast<int>(customized_func_results_.size()) <= used_control_node_) {
929         MS_LOG(ERROR) << "customized_func_results_ size is smaller than used_control_node_.";
930       }
931       switch_branch_ = customized_func_results_[used_control_node_];
932       used_control_node_ += 1;
933     }
934     if (switch_branch_) {
935       switch_c_node = fg->NewCNode({switch_v_node, control_node, switch_partial_clone_c, switch_partial_fake_c});
936     } else {
937       switch_c_node = fg->NewCNode({switch_v_node, control_node, switch_partial_fake_c, switch_partial_clone_c});
938     }
939     if (switch_c_node == nullptr) {
940       MS_LOG(ERROR) << "switch_c_node is nullptr.";
941       return;
942     }
943     switch_c_node->set_abstract(fg_subgraph_clone->ToAbstract());
944     fg->AddNode(switch_c_node);
945 
946     mindspore::CNodePtr call_cnode = fg->NewCNode({switch_c_node});
947     if (call_cnode == nullptr) {
948       MS_LOG(ERROR) << "call_cnode is nullptr.";
949       return;
950     }
951     fg->AddNode(call_cnode);
952 
953     if (child_node != nullptr) {
954       unsigned i = 0;
955       for (auto &weak_input : child_node->weak_inputs()) {
956         auto input = weak_input.lock();
957         MS_EXCEPTION_IF_NULL(input);
958         if (input->fullname_with_scope() == last_node->fullname_with_scope()) {
959           child_node->set_input(i, call_cnode);
960           break;
961         }
962         i++;
963       }
964       switch_num_++;
965     }
966   }
967   MS_LOG(WARNING) << switch_num_ << " switch nodes have been added.";
968   used_control_node_ = 0;
969 }
970 
GetNodeMaxNum(const AnfNodeSet nodes)971 int GetNodeMaxNum(const AnfNodeSet nodes) {
972   int node_max_num = 0;
973   for (auto node : nodes) {
974     if (node != nullptr && node->isa<CNode>()) {
975       int op_num = get_op_num(node);
976       if (op_num > node_max_num) {
977         node_max_num = op_num;
978       }
979     }
980   }
981   return node_max_num;
982 }
983 
NodePrepareCheck(const mindspore::AnfNodePtr & node,const int & branch_control_input)984 bool NodePrepareCheck(const mindspore::AnfNodePtr &node, const int &branch_control_input) {
985   std::string ignore_name = "down_sample_layer";
986   if (node == nullptr) {
987     MS_LOG(INFO) << "Find null node!" << std::endl;
988     return false;
989   }
990   if (!node->isa<CNode>()) {
991     MS_LOG(INFO) << "Not a Cnode." << std::endl;
992     return false;
993   }
994   // Ignore ResNet's down_sample_layer node for customized func mode.
995   if ((branch_control_input == 0) && (node->fullname_with_scope().find(ignore_name) != std::string::npos)) {
996     MS_LOG(INFO) << "Find down_sample_layer node: " << node->fullname_with_scope() << std::endl;
997     return false;
998   }
999   return true;
1000 }
1001 
IsValidOpNum(const int & current_num,const int & compa_num) const1002 bool DynamicObfuscator::IsValidOpNum(const int &current_num, const int &compa_num) const {
1003   if (branch_control_input_ != 0) {
1004     return true;
1005   }
1006   return current_num <= compa_num;
1007 }
1008 
SubGraphFakeBranch(const FuncGraphPtr func_graph)1009 void DynamicObfuscator::SubGraphFakeBranch(const FuncGraphPtr func_graph) {
1010   if (func_graph == nullptr) {
1011     MS_LOG(ERROR) << "Build fake sub-graph failed: FuncGraph is nullptr.";
1012     return;
1013   }
1014   node_names_.push("-");
1015   auto mgr = mindspore::Manage(func_graph);
1016   if (mgr == nullptr) {
1017     MS_LOG(ERROR) << "Manager is null node!";
1018     return;
1019   }
1020   auto all_nodes = mgr->all_nodes();
1021   int node_nums = SizeToInt(all_nodes.size());
1022   int obfuscate_target_num = std::ceil(node_nums * obf_ratio_ / keyExpandRate);
1023   int op_num = GetNodeMaxNum(all_nodes);
1024   MS_LOG(INFO) << "Init op_num is: " << op_num;
1025   std::vector<mindspore::AnfNodePtr> sorted_nodes;
1026   for (auto node : all_nodes) {
1027     MS_LOG(INFO) << "The last node name is: " << node->fullname_with_scope();
1028     sorted_nodes = TopoSort(node);  // the node number in front of sorted nodes is the smallest
1029     break;
1030   }
1031   std::reverse(sorted_nodes.begin(), sorted_nodes.end());
1032   for (auto node : sorted_nodes) {
1033     if (!NodePrepareCheck(node, branch_control_input_)) {
1034       continue;
1035     }
1036     std::string cnode_name = get_node_prim_name(node);
1037     MS_LOG(INFO) << "CNode name is: " << cnode_name;
1038     int cur_op_num = get_op_num(node);
1039     float dropout_rate = 0.1;
1040     int dropout_rand = rand() % static_cast<int>(1.0 / dropout_rate);
1041     if (IsTarget(cnode_name) && IsValidOpNum(cur_op_num, op_num) && dropout_rand != 0 &&
1042         (node_dict_.find(node->fullname_with_scope()) == node_dict_.cend())) {
1043       UpdateDict(node, false);
1044       op_num = cur_op_num;
1045       bool stop_traverse = false;
1046       mindspore::CNodePtr curr_cnode = node->cast<mindspore::CNodePtr>();
1047       while (!stop_traverse) {
1048         mindspore::CNodePtr valid_input = CheckInputNodes(curr_cnode);
1049         dropout_rand = rand() % static_cast<int>(1.0 / dropout_rate);
1050         if (valid_input && dropout_rand != 0 &&
1051             (node_dict_.find(valid_input->fullname_with_scope()) == node_dict_.cend())) {
1052           UpdateDict(valid_input, false);
1053           op_num = get_op_num(valid_input);
1054           curr_cnode = valid_input;
1055         } else {
1056           stop_traverse = true;
1057           if (curr_cnode->size() > 1) {
1058             CheckDuplicatedParent(curr_cnode->inputs()[1]);
1059           }
1060         }
1061       }
1062     }
1063     if (subgraph_obf_num_ >= obfuscate_target_num) {
1064       break;
1065     }
1066   }
1067   node_names_.pop();
1068   if (branch_control_input_ == 0) {
1069     mindspore::kernel::CustomizedOpaquePredicate::GetInstance().init_calling_count();
1070   }
1071   AddSwitchNode(func_graph);
1072   MS_LOG(WARNING) << subgraph_obf_num_ << " nodes have been obfuscated.";
1073 }
1074 }  // namespace mindspore
1075