1 /**
2 * Copyright 2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
#include "include/common/utils/dynamic_obfuscation/dynamic_obfuscation.h"
#include <algorithm>
#include <exception>
#include <functional>
#include <map>
#include <memory>
#include <numeric>
#include <random>
#include "ops/conv_pool_op_name.h"
#include "ops/math_op_name.h"
#include "ops/other_ops.h"
#include "ops/comparison_ops.h"
#include "ops/array_ops.h"
#include "ops/auto_generate/gen_ops_primitive.h"
#include "ops/framework_ops.h"
#include "include/common/debug/anf_ir_dump.h"
#include "include/common/utils/dynamic_obfuscation/registry_opaque_predicate.h"
#include "include/common/utils/utils.h"
#include "ir/anf.h"
#include "ir/tensor.h"
#include "utils/info.h"
35
36 namespace mindspore {
37 namespace {
AddObfuscatedParam(FuncGraphPtr func_graph)38 ParameterPtr AddObfuscatedParam(FuncGraphPtr func_graph) {
39 auto params = func_graph->parameters();
40 auto add_param = std::make_shared<Parameter>(func_graph);
41 std::vector<AnfNodePtr> new_para_list(params.begin(), params.begin() + params.size() - func_graph->fv_param_count());
42 (void)new_para_list.emplace_back(add_param);
43 (void)new_para_list.insert(new_para_list.cend(), params.begin() + params.size() - func_graph->fv_param_count(),
44 params.end());
45 func_graph->set_parameters(new_para_list);
46 return add_param;
47 }
48 } // namespace
49 using Tensor = mindspore::tensor::Tensor;
50 using mindspore::abstract::AbstractTensor;
51 using mindspore::abstract::AbstractTensorPtr;
52 using mindspore::abstract::AbstractTuple;
53 using mindspore::abstract::AbstractTuplePtr;
54
// Total number of nodes needed for one switch graph.
constexpr int keyExpandRate{10};
// NOTE(review): presumably the input position of the weight in a node that has weights - confirm against users.
constexpr int kWeightIndex{2};
// NOTE(review): presumably the expected number of inputs of a switch node - confirm against users.
constexpr int kSwitchInputsNum{2};
// NOTE(review): presumably the expected number of inputs of a node that has weights - confirm against users.
constexpr int kNodeWithWeightInputsNum{3};
59
get_node_shape(const AnfNodePtr & input_node)60 ShapeVector get_node_shape(const AnfNodePtr &input_node) {
61 if (input_node == nullptr) {
62 MS_LOG(ERROR) << "Input node is nullptr, get shape failed!";
63 return {};
64 }
65 AbstractBasePtr input_abstract = input_node->abstract();
66 if (input_abstract == nullptr) {
67 MS_LOG(ERROR) << "The abstract of input_node is nullptr, get shape failed!";
68 return {};
69 }
70 AbstractTensorPtr input_abstract_tensor = input_abstract->cast<mindspore::abstract::AbstractTensorPtr>();
71 MS_EXCEPTION_IF_NULL(input_abstract_tensor);
72 mindspore::abstract::ShapePtr shape_ptr = input_abstract_tensor->shape();
73 if (shape_ptr == nullptr) {
74 return {};
75 }
76 return shape_ptr->shape();
77 }
78
get_node_dtype(const AnfNodePtr & input_node)79 TypeId get_node_dtype(const AnfNodePtr &input_node) {
80 if (input_node == nullptr) {
81 MS_LOG(ERROR) << "Input node is nullptr, get dtype failed!";
82 return {};
83 }
84 AbstractBasePtr input_abstract = input_node->abstract();
85 if (input_abstract == nullptr) {
86 MS_LOG(ERROR) << "The abstract of input_node is nullptr, get dtype failed!";
87 return {};
88 }
89 AbstractTensorPtr input_abstract_tensor = input_abstract->cast<mindspore::abstract::AbstractTensorPtr>();
90 MS_EXCEPTION_IF_NULL(input_abstract_tensor);
91 AbstractBasePtr node_element = input_abstract_tensor->element();
92 mindspore::abstract::AbstractScalarPtr node_element_abs =
93 node_element->cast<mindspore::abstract::AbstractScalarPtr>();
94 MS_EXCEPTION_IF_NULL(node_element_abs);
95 TypeId data_type = node_element_abs->BuildType()->type_id();
96 return data_type;
97 }
98
// Splits node_name_ at every occurrence of split_sign and returns the pieces in order.
// Empty pieces are kept, so the result always holds at least one element ("" for an empty input, and a
// trailing "" when the input ends with the delimiter).
// An empty split_sign cannot delimit anything: the whole input is returned as the single piece
// (the previous implementation never terminated in that case).
// The delimiter is only matched inside the input itself, never across its end.
std::vector<std::string> name_split(const std::string &node_name_, const std::string &split_sign) {
  if (split_sign.empty()) {
    return {node_name_};
  }
  std::vector<std::string> res;
  std::string::size_type piece_begin = 0;
  while (true) {
    const std::string::size_type split_pos = node_name_.find(split_sign, piece_begin);
    if (split_pos == std::string::npos) {
      // Remainder after the last delimiter (possibly empty).
      res.push_back(node_name_.substr(piece_begin));
      break;
    }
    res.push_back(node_name_.substr(piece_begin, split_pos - piece_begin));
    piece_begin = split_pos + split_sign.size();
  }
  return res;
}
115
get_node_prim_name(const AnfNodePtr & node)116 std::string get_node_prim_name(const AnfNodePtr &node) {
117 if (node == nullptr) {
118 MS_LOG(ERROR) << "Input node is nullptr, get name failed!";
119 return "";
120 }
121 PrimitivePtr node_prim = GetCNodePrimitive(node);
122 if (node_prim == nullptr) {
123 MS_LOG(DEBUG) << "The primitive of node " << node->fullname_with_scope() << " is nullptr!";
124 return "";
125 }
126 return node_prim->ToString();
127 }
128
get_op_num(const AnfNodePtr & node)129 int get_op_num(const AnfNodePtr &node) {
130 if (node == nullptr) {
131 MS_LOG(ERROR) << "Input node is nullptr, get name failed!";
132 return 0;
133 }
134 std::string node_name = node->fullname_with_scope();
135 std::vector<string> split_words = name_split(node_name, "op");
136 if (split_words.empty()) {
137 MS_LOG(WARNING) << "Input node name is empty.";
138 return 0;
139 }
140 std::string num = split_words[split_words.size() - 1];
141 return std::stoi(num);
142 }
143
get_node_param(const FuncGraphPtr func_graph,const CNodePtr & node)144 ParameterPtr get_node_param(const FuncGraphPtr func_graph, const CNodePtr &node) {
145 if (node == nullptr) {
146 MS_LOG(ERROR) << "Node is nullptr, get param failed!";
147 return nullptr;
148 }
149 if (func_graph == nullptr) {
150 MS_LOG(ERROR) << "FuncGraph is nullptr, get param failed!";
151 return nullptr;
152 }
153 std::string parameter_name = "";
154 for (auto &weak_input : node->weak_inputs()) {
155 auto input = weak_input.lock();
156 MS_EXCEPTION_IF_NULL(input);
157 std::string op_name = get_node_prim_name(input);
158 MS_LOG(INFO) << "op_name is: " << op_name;
159 if (op_name == "Load") {
160 for (auto weak_param : input->cast<mindspore::CNodePtr>()->weak_inputs()) {
161 auto param = weak_param.lock();
162 MS_EXCEPTION_IF_NULL(param);
163 if (param->fullname_with_scope().find("weight") != std::string::npos) {
164 parameter_name = param->fullname_with_scope();
165 break;
166 }
167 }
168 }
169 }
170 for (auto param : func_graph->parameters()) {
171 auto param_node = param->cast<mindspore::ParameterPtr>();
172 if (param_node == nullptr) {
173 MS_LOG(ERROR) << "Param node is nullptr.";
174 return nullptr;
175 }
176 if (param->fullname_with_scope() == parameter_name) {
177 return param_node;
178 }
179 }
180 return nullptr;
181 }
182
build_tuple_value_node(const std::vector<int64_t> & values)183 ValueNodePtr build_tuple_value_node(const std::vector<int64_t> &values) {
184 mindspore::ValueNodePtr v_node = std::make_shared<mindspore::ValueNode>(MakeValue(values));
185 AbstractBasePtrList abs_list;
186 (void)std::transform(values.cbegin(), values.cend(), std::back_inserter(abs_list), [](const int64_t &item) {
187 return std::make_shared<mindspore::abstract::AbstractScalar>(int64_t(item));
188 });
189 auto abs_tuple = std::make_shared<mindspore::abstract::AbstractTuple>(abs_list);
190 v_node->set_abstract(abs_tuple);
191 return v_node;
192 }
193
make_int_node(const FuncGraphPtr func_graph,int int_value)194 ValueNodePtr make_int_node(const FuncGraphPtr func_graph, int int_value) {
195 ShapeVector int_shape{1};
196 tensor::TensorPtr int_tensor = std::make_shared<Tensor>(mindspore::kNumberTypeInt32, int_shape);
197 int *tensor_data = reinterpret_cast<int *>(int_tensor->data_c());
198 for (int i = 0; i < int_tensor->data().size(); i++) {
199 tensor_data[i] = int_value;
200 }
201 mindspore::ValueNodePtr int_tensor_node = std::make_shared<mindspore::ValueNode>(int_tensor);
202 int_tensor_node->set_abstract(int_tensor->ToAbstract());
203 func_graph->AddValueNode(int_tensor_node);
204 return int_tensor_node;
205 }
206
make_weight_tensor(TypeId type_id,ShapeVector shape)207 tensor::TensorPtr make_weight_tensor(TypeId type_id, ShapeVector shape) {
208 tensor::TensorPtr weight_tensor = std::make_shared<Tensor>(type_id, shape);
209 std::default_random_engine generator;
210 int max_count = 10000;
211 int tensor_size = SizeToInt(weight_tensor->data().size());
212 if (type_id == kNumberTypeFloat64) {
213 const double mean_64 = 0;
214 const double stddev_64 = 1;
215 std::normal_distribution<double> dist_64(mean_64, stddev_64);
216 double *float_64_data = reinterpret_cast<double *>(weight_tensor->data_c());
217 for (int i = 0; i < std::min(tensor_size, max_count); i++) {
218 double random_float_64 = dist_64(generator);
219 if (random_float_64 > 0) {
220 float_64_data[i] = random_float_64;
221 }
222 }
223 } else {
224 MS_LOG(DEBUG) << "Type id is: " << type_id << ", weights will be float_32 format.";
225 const float mean = 0;
226 const float stddev = 1;
227 std::normal_distribution<float> dist_32(mean, stddev);
228 float *float_32_data = reinterpret_cast<float *>(weight_tensor->data_c());
229 for (int i = 0; i < std::min(tensor_size, max_count); i++) {
230 float random_float_32 = dist_32(generator);
231 if (random_float_32 > 0) {
232 float_32_data[i] = random_float_32;
233 }
234 }
235 }
236 return weight_tensor;
237 }
238
CheckIfObfuscated(const FuncGraphPtr & func_graph)239 bool CheckIfObfuscated(const FuncGraphPtr &func_graph) {
240 MS_EXCEPTION_IF_NULL(func_graph);
241 auto mgr = Manage(func_graph);
242 MS_EXCEPTION_IF_NULL(mgr);
243 auto all_nodes = mgr->all_nodes();
244 for (AnfNodePtr node : all_nodes) {
245 MS_EXCEPTION_IF_NULL(node);
246 std::string node_name = node->fullname_with_scope();
247 if (node_name.find("Switch") != std::string::npos) {
248 return true;
249 }
250 }
251 return false;
252 }
253
// Entry point of the obfuscation: refuses graphs that CheckIfObfuscated reports as already obfuscated,
// broadens the abstract of every managed node that has one, runs SubGraphFakeBranch on the graph and
// returns the same graph. A warning is logged when subgraph_obf_num_ is still 0 afterwards.
FuncGraphPtr DynamicObfuscator::ObfuscateMindIR(const FuncGraphPtr &func_graph) {
  MS_LOG(INFO) << "Start obfuscation.";
  MS_EXCEPTION_IF_NULL(func_graph);
  if (CheckIfObfuscated(func_graph)) {
    MS_EXCEPTION(ValueError) << "The input model has been obfuscated, do not obfuscate it again.";
  }
  auto mgr = Manage(func_graph);
  MS_EXCEPTION_IF_NULL(mgr);
  auto all_nodes = mgr->all_nodes();
  for (const auto &item : all_nodes) {
    auto abs = item->abstract();
    if (abs != nullptr) {
      item->set_abstract(abs->Broaden());
    }
  }
  // Log the size directly instead of narrowing it to int first.
  MS_LOG(INFO) << "Total node num: " << all_nodes.size();

  // do subgraph fake-branch obfuscation
  SubGraphFakeBranch(func_graph);

  if (subgraph_obf_num_ == 0) {
    MS_LOG(WARNING)
      << "The model has not been obfuscated, which means obf_random_seed or customized_func is not need to set.";
  }
  return func_graph;
}
281
ObfuscateOpType(const AnfNodePtr & node)282 std::string DynamicObfuscator::ObfuscateOpType(const AnfNodePtr &node) {
283 if (node == nullptr) {
284 MS_LOG(ERROR) << "Input node is nullptr, get name failed!";
285 return "";
286 }
287 if (node->isa<CNode>()) {
288 MS_LOG(INFO) << "The node_name is: " << node->fullname_with_scope();
289 std::string op_name = get_node_prim_name(node);
290 std::vector<std::string> target_op_list;
291 target_op_list.insert(target_op_list.end(), single_input_target_op_.begin(), single_input_target_op_.end());
292 target_op_list.insert(target_op_list.end(), single_input_with_weight_target_op_.begin(),
293 single_input_with_weight_target_op_.end());
294
295 auto found = std::find_if(target_op_list.cbegin(), target_op_list.cend(),
296 [&](const auto &target_name) { return op_name == target_name; });
297 if (found != target_op_list.cend()) {
298 return *found;
299 }
300 }
301 return "";
302 }
303
ObfuscateOpCase(const std::string obf_type)304 ObfCase DynamicObfuscator::ObfuscateOpCase(const std::string obf_type) {
305 if (obf_type.empty()) {
306 MS_LOG(ERROR) << "Obf_type is empty string.";
307 return ObfCase::NotObfNode;
308 }
309 auto name_equal = [&obf_type](const std::string &s) { return s == obf_type; };
310 if (std::any_of(single_input_target_op_.begin(), single_input_target_op_.end(), name_equal)) {
311 return ObfCase::OneInputNoWeightNode;
312 } else if (std::any_of(single_input_with_weight_target_op_.begin(), single_input_with_weight_target_op_.end(),
313 name_equal)) {
314 return ObfCase::OneInputWithWeightNode;
315 } else {
316 return ObfCase::NotObfNode;
317 }
318 }
319
RandomSeedModeControl(const FuncGraphPtr func_graph)320 CNodePtr DynamicObfuscator::RandomSeedModeControl(const FuncGraphPtr func_graph) {
321 ShapeVector y_shape{1};
322 tensor::TensorPtr y_tensor = std::make_shared<Tensor>(mindspore::kNumberTypeInt32, y_shape);
323 if (!has_build_appended_input) {
324 MS_LOG(INFO) << "Build parameter y_append.";
325 auto y_append = AddObfuscatedParam(func_graph);
326 y_append->set_name("y_append");
327 y_append->set_abstract(y_tensor->ToAbstract());
328 has_build_appended_input = true;
329 }
330 auto y_append = func_graph->GetParameterByName("y_append");
331
332 if (used_control_node_ == 0) {
333 // make equal function node
334 ValueNodePtr equal_v_node = std::make_shared<mindspore::ValueNode>(mindspore::prim::kPrimEqual);
335 func_graph->AddValueNode(equal_v_node);
336 ValueNodePtr equal_compa_node = make_int_node(func_graph, branch_control_input_);
337 CNodePtr equal_c_node = func_graph->NewCNode({equal_v_node, y_append, equal_compa_node});
338 if (equal_c_node == nullptr) {
339 MS_LOG(ERROR) << "equal_c_node is nullptr.";
340 return nullptr;
341 }
342 tensor::TensorPtr equal_tensor = std::make_shared<Tensor>(mindspore::kNumberTypeBool, y_shape);
343 equal_c_node->set_abstract(equal_tensor->ToAbstract());
344 func_graph->AddNode(equal_c_node);
345 used_control_node_ += 1;
346 switch_branch_ = true;
347 return equal_c_node;
348 }
349 // make greater function node
350 int comparison_int = rand();
351 ValueNodePtr greater_v_node = std::make_shared<mindspore::ValueNode>(mindspore::prim::kPrimGreater);
352 func_graph->AddValueNode(greater_v_node);
353 ValueNodePtr greater_compa_node = make_int_node(func_graph, comparison_int);
354 CNodePtr greater_c_node = func_graph->NewCNode({greater_v_node, y_append, greater_compa_node});
355 if (greater_c_node == nullptr) {
356 MS_LOG(ERROR) << "greater_c_node is nullptr.";
357 return nullptr;
358 }
359 tensor::TensorPtr greater_tensor = std::make_shared<Tensor>(mindspore::kNumberTypeBool, y_shape);
360 greater_c_node->set_abstract(greater_tensor->ToAbstract());
361 func_graph->AddNode(greater_c_node);
362 used_control_node_ += 1;
363 switch_branch_ = branch_control_input_ > comparison_int;
364 return greater_c_node;
365 }
366
CreateScalarValue(const FuncGraphPtr & func_graph,int64_t value)367 ValueNodePtr CreateScalarValue(const FuncGraphPtr &func_graph, int64_t value) {
368 auto scalar_value = MakeValue(value);
369 auto scalar_node = NewValueNode(scalar_value);
370 scalar_node->set_abstract(scalar_value->ToAbstract());
371 func_graph->AddValueNode(scalar_node);
372 return scalar_node;
373 }
374
add_stride_slice_node(FuncGraphPtr func_graph,ShapeVector begin_vector,ShapeVector stride_vector,ShapeVector end_vector,int end_mask,int begin_mask,mindspore::CNodePtr prev_node)375 mindspore::CNodePtr add_stride_slice_node(FuncGraphPtr func_graph, ShapeVector begin_vector, ShapeVector stride_vector,
376 ShapeVector end_vector, int end_mask, int begin_mask,
377 mindspore::CNodePtr prev_node) {
378 mindspore::ValueNodePtr begin_v_node = build_tuple_value_node(begin_vector);
379 mindspore::ValueNodePtr stride_v_node = build_tuple_value_node(stride_vector);
380 mindspore::ValueNodePtr end_v_node = build_tuple_value_node(end_vector);
381 auto begin_mask_node = CreateScalarValue(func_graph, begin_mask);
382 MS_EXCEPTION_IF_NULL(begin_mask_node);
383 auto end_mask_node = CreateScalarValue(func_graph, end_mask);
384 MS_EXCEPTION_IF_NULL(end_mask_node);
385 auto ellipsis_mask_node = CreateScalarValue(func_graph, int64_t(0));
386 MS_EXCEPTION_IF_NULL(ellipsis_mask_node);
387 auto new_axis_mask_node = CreateScalarValue(func_graph, int64_t(0));
388 MS_EXCEPTION_IF_NULL(new_axis_mask_node);
389 auto shrink_axis_mask_node = CreateScalarValue(func_graph, int64_t(1));
390 MS_EXCEPTION_IF_NULL(shrink_axis_mask_node);
391 func_graph->AddValueNode(begin_v_node);
392 func_graph->AddValueNode(stride_v_node);
393 func_graph->AddValueNode(end_v_node);
394 mindspore::PrimitivePtr slice_prim = mindspore::prim::kPrimStridedSlice;
395 slice_prim->set_attr("is_load", MakeValue(true));
396 mindspore::ValueNodePtr slice_v_node = std::make_shared<mindspore::ValueNode>(slice_prim);
397 func_graph->AddValueNode(slice_v_node);
398 mindspore::CNodePtr slice_c_node =
399 func_graph->NewCNode({slice_v_node, prev_node, begin_v_node, end_v_node, stride_v_node, begin_mask_node,
400 end_mask_node, ellipsis_mask_node, new_axis_mask_node, shrink_axis_mask_node});
401 return slice_c_node;
402 }
403
CustomOpModeControl(const FuncGraphPtr func_graph,const AnfNodePtr & prev_node) const404 CNodePtr DynamicObfuscator::CustomOpModeControl(const FuncGraphPtr func_graph, const AnfNodePtr &prev_node) const {
405 mindspore::PrimitivePtr reshape_prim = mindspore::prim::kPrimReshape;
406 reshape_prim->set_attr("is_load", MakeValue(true));
407 mindspore::ValueNodePtr reshape_v_node = std::make_shared<mindspore::ValueNode>(reshape_prim);
408 func_graph->AddValueNode(reshape_v_node);
409 ShapeVector prev_node_shape = get_node_shape(prev_node);
410 int shape_multiply = std::accumulate(prev_node_shape.cbegin(), prev_node_shape.cend(), 1, std::multiplies<int>());
411 MS_LOG(INFO) << "The shape_multiply is: " << shape_multiply;
412
413 ShapeVector flat_shape{1, shape_multiply};
414 mindspore::ValueNodePtr shape_v_node = std::make_shared<mindspore::ValueNode>(MakeValue(flat_shape));
415 func_graph->AddValueNode(shape_v_node);
416 mindspore::CNodePtr reshape_c_node = func_graph->NewCNode({reshape_v_node, prev_node, shape_v_node});
417 TypeId data_type = get_node_dtype(prev_node);
418 auto reshape_abstract = std::make_shared<Tensor>(data_type, flat_shape)->ToAbstract();
419 reshape_c_node->set_abstract(reshape_abstract);
420 func_graph->AddNode(reshape_c_node);
421
422 // the first stride_slice x[0]
423 ShapeVector begin_1{0, 0};
424 ShapeVector stride_1{1, 1};
425 mindspore::CNodePtr slice_c_node_1 =
426 add_stride_slice_node(func_graph, begin_1, stride_1, flat_shape, 2, 2, reshape_c_node);
427 ShapeVector slice_1_shape{shape_multiply};
428 slice_c_node_1->set_abstract(std::make_shared<Tensor>(data_type, slice_1_shape)->ToAbstract());
429 func_graph->AddNode(slice_c_node_1);
430
431 // the first stride_slice x[0][0]
432 ShapeVector begin_2{0};
433 ShapeVector end_2{1};
434 ShapeVector stride_2{1};
435 mindspore::CNodePtr slice_c_node_2 =
436 add_stride_slice_node(func_graph, begin_2, stride_2, stride_2, 0, 0, slice_c_node_1);
437 ShapeVector slice_2_shape{1};
438 slice_c_node_2->set_abstract(std::make_shared<Tensor>(data_type, slice_2_shape)->ToAbstract());
439 func_graph->AddNode(slice_c_node_2);
440
441 // the second stride_slice x[0][1]
442 ShapeVector begin_3{1};
443 ShapeVector end_3{1};
444 ShapeVector stride_3{2};
445 mindspore::CNodePtr slice_c_node_3 =
446 add_stride_slice_node(func_graph, begin_3, stride_3, stride_3, 0, 0, slice_c_node_1);
447 ShapeVector slice_3_shape{1};
448 slice_c_node_3->set_abstract(std::make_shared<Tensor>(data_type, slice_3_shape)->ToAbstract());
449 func_graph->AddNode(slice_c_node_3);
450
451 // add opaque predicate
452 PrimitivePtr custom_prim = mindspore::prim::kPrimOpaquePredicate;
453 custom_prim->set_attr("is_load", MakeValue(true));
454 std::vector<ValuePtr> input_names_value;
455 input_names_value.push_back(std::make_shared<StringImm>("x"));
456 input_names_value.push_back(std::make_shared<StringImm>("y"));
457 custom_prim->set_attr(mindspore::kAttrInputNames, std::make_shared<ValueList>(input_names_value));
458 std::vector<ValuePtr> output_names_value;
459 output_names_value.push_back(std::make_shared<StringImm>("output"));
460 custom_prim->set_attr(mindspore::kAttrOutputNames, std::make_shared<ValueList>(output_names_value));
461 auto opaque_v_node = std::make_shared<mindspore::ValueNode>(custom_prim);
462 func_graph->AddValueNode(opaque_v_node);
463 auto opaque_c_node = func_graph->NewCNode({opaque_v_node, slice_c_node_2, slice_c_node_3});
464 ShapeVector y_shape{1};
465 auto bool_tensor = std::make_shared<Tensor>(mindspore::kNumberTypeBool, y_shape);
466 opaque_c_node->set_abstract(bool_tensor->ToAbstract());
467 func_graph->AddNode(opaque_c_node);
468 return opaque_c_node;
469 }
470
GetControlNode(const FuncGraphPtr & func_graph,const AnfNodePtr & prev_node)471 CNodePtr DynamicObfuscator::GetControlNode(const FuncGraphPtr &func_graph, const AnfNodePtr &prev_node) {
472 MS_EXCEPTION_IF_NULL(func_graph);
473 MS_EXCEPTION_IF_NULL(prev_node);
474 if (branch_control_input_ != 0) {
475 MS_LOG(INFO) << "Run password mode.";
476 return RandomSeedModeControl(func_graph);
477 }
478 MS_LOG(INFO) << "Run customized function mode.";
479 if (prev_node != nullptr && prev_node->abstract() != nullptr) {
480 return CustomOpModeControl(func_graph, prev_node);
481 }
482 return nullptr;
483 }
484
// Chooses the primitive used in place of the op named by the part of obf_type before the first "-".
// - MaxPool is replaced by a new AvgPool primitive and AvgPool by a new MaxPool primitive; the attributes
//   input_names, output_names, format, pad_mode, kernel_size and strides are copied from node's primitive.
//   NOTE(review): pad_mode is copied without the null check the other attributes get - confirm intended.
// - Any other op gets a random entry of one_input_prim_ whose ToString() differs from the original name.
// Returns nullptr when one_input_prim_ is empty or holds no entry that differs from the original
// (previously: modulo by zero / endless loop).
mindspore::PrimitivePtr DynamicObfuscator::get_random_prim(const std::string &obf_type,
                                                           const mindspore::CNodePtr &node) {
  std::vector<std::string> split_words = name_split(obf_type, "-");
  if (split_words.empty()) {
    MS_LOG(WARNING) << "obf_type is empty.";
    return nullptr;
  }
  std::string prim_name_ori = split_words[0];
  mindspore::PrimitivePtr poolptr = nullptr;
  if (prim_name_ori == kMaxPoolOpName || prim_name_ori == kAvgPoolOpName) {
    if (prim_name_ori == kMaxPoolOpName) {
      poolptr = std::make_shared<Primitive>("AvgPool");
    } else {
      poolptr = std::make_shared<Primitive>("MaxPool");
    }
    auto primitive = GetCNodePrimitive(node);
    MS_EXCEPTION_IF_NULL(primitive);
    MS_EXCEPTION_IF_NULL(primitive->GetAttr("input_names"));
    MS_EXCEPTION_IF_NULL(primitive->GetAttr("output_names"));
    MS_EXCEPTION_IF_NULL(primitive->GetAttr("format"));
    MS_EXCEPTION_IF_NULL(primitive->GetAttr("kernel_size"));
    MS_EXCEPTION_IF_NULL(primitive->GetAttr("strides"));
    poolptr->set_attr("input_names", primitive->GetAttr("input_names"));
    poolptr->set_attr("output_names", primitive->GetAttr("output_names"));
    poolptr->set_attr("format", primitive->GetAttr("format"));
    poolptr->set_attr("pad_mode", primitive->GetAttr("pad_mode"));
    poolptr->set_attr("kernel_size", primitive->GetAttr("kernel_size"));
    poolptr->set_attr("strides", primitive->GetAttr("strides"));
    return poolptr;
  }
  if (one_input_prim_.empty()) {
    MS_LOG(ERROR) << "No candidate primitive is available, get random prim failed.";
    return nullptr;
  }
  const auto differs_from_origin = [&prim_name_ori](const auto &prim) {
    return prim != nullptr && prim->ToString() != prim_name_ori;
  };
  // Without at least one usable candidate the rejection loop below could never finish.
  if (!std::any_of(one_input_prim_.cbegin(), one_input_prim_.cend(), differs_from_origin)) {
    MS_LOG(ERROR) << "No candidate primitive differs from " << prim_name_ori << ", get random prim failed.";
    return nullptr;
  }
  mindspore::PrimitivePtr prim_node = nullptr;
  do {
    int random = rand() % SizeToInt(one_input_prim_.size());
    prim_node = one_input_prim_[random];
  } while (!differs_from_origin(prim_node));
  return prim_node;
}
522
UpdateDict(const AnfNodePtr & node,const bool isParent)523 void DynamicObfuscator::UpdateDict(const AnfNodePtr &node, const bool isParent) {
524 if (node == nullptr) {
525 MS_LOG(ERROR) << "Input node is nullptr, update dict failed.";
526 return;
527 }
528 MS_LOG(INFO) << "Update: " << node->fullname_with_scope() << " to dict.";
529 if (isParent) {
530 parent_names_.push(node->fullname_with_scope());
531 } else {
532 node_names_.push(node->fullname_with_scope());
533 subgraph_obf_num_++;
534 }
535 node_dict_[node->fullname_with_scope()] = node->cast<mindspore::AnfNodePtr>();
536 if (node_dict_[node->fullname_with_scope()] == nullptr) {
537 MS_LOG(ERROR) << "Update node " << node->fullname_with_scope() << " failed.";
538 }
539 }
540
CheckDuplicatedParent(const AnfNodePtr & node)541 void DynamicObfuscator::CheckDuplicatedParent(const AnfNodePtr &node) {
542 if (node == nullptr) {
543 MS_LOG(ERROR) << "Input node is nullptr, check parent failed.";
544 return;
545 }
546 if (node_dict_.find(node->fullname_with_scope()) != node_dict_.cend()) {
547 while (node_names_.top() != "-") {
548 node_dict_.erase(node_names_.top());
549 node_names_.pop();
550 subgraph_obf_num_--;
551 }
552 } else {
553 node_names_.push("-");
554 UpdateDict(node, true);
555 if (branch_control_input_ == 0) {
556 bool customized_func_result = mindspore::kernel::CustomizedOpaquePredicate::GetInstance().run_function(
557 static_cast<float>(1), static_cast<float>(1));
558 customized_func_results_.push_back(customized_func_result);
559 }
560 }
561 }
562
IsTarget(const std::string & cnode_name)563 bool DynamicObfuscator::IsTarget(const std::string &cnode_name) {
564 if (cnode_name.empty()) {
565 MS_LOG(INFO) << "CNode name is empty.";
566 return false;
567 }
568 std::vector<std::string> target_op_list;
569 target_op_list.insert(target_op_list.end(), single_input_target_op_.begin(), single_input_target_op_.end());
570 target_op_list.insert(target_op_list.end(), single_input_with_weight_target_op_.begin(),
571 single_input_with_weight_target_op_.end());
572 if (std::find(target_op_list.cbegin(), target_op_list.cend(), cnode_name) != target_op_list.cend()) {
573 return true;
574 }
575 return false;
576 }
577
// Returns the first input of node whose primitive name is a target op (see IsTarget), cast to CNodePtr.
// Returns nullptr when node is null or none of its inputs is a target.
mindspore::CNodePtr DynamicObfuscator::CheckInputNodes(const mindspore::CNodePtr &node) {
  if (node == nullptr) {
    MS_LOG(ERROR) << "Input node is nullptr, check input failed.";
    return nullptr;
  }
  const auto &candidates = node->inputs();
  const auto hit = std::find_if(candidates.begin(), candidates.end(), [this](const auto &input_node) {
    return IsTarget(get_node_prim_name(input_node));
  });
  if (hit == candidates.end()) {
    return nullptr;
  }
  return (*hit)->template cast<mindspore::CNodePtr>();
}
592
// Creates the CNode prim_node(input_node) inside fg.
// prim_node gets the attributes "is_load" = true and input names ["x"]; the new node takes over the shape
// and dtype of input_node as its abstract and is added to fg.
// Returns nullptr when an argument is null or the node / its abstract cannot be created.
mindspore::CNodePtr DynamicObfuscator::BuildOneInputNoWeightNode(const FuncGraphPtr &fg,
                                                                 const mindspore::AnfNodePtr &input_node,
                                                                 const mindspore::PrimitivePtr prim_node) const {
  if (input_node == nullptr) {
    MS_LOG(ERROR) << "Build Node failed: input node is nullptr.";
    return nullptr;
  }
  if (fg == nullptr) {
    MS_LOG(ERROR) << "Build Node failed: FuncGraph is nullptr.";
    return nullptr;
  }
  if (prim_node == nullptr) {
    MS_LOG(ERROR) << "Build Node failed: prim_node is nullptr.";
    return nullptr;
  }
  std::vector<ValuePtr> input_names{std::make_shared<StringImm>("x")};
  prim_node->set_attr("is_load", MakeValue(true));
  prim_node->set_attr(mindspore::kAttrInputNames, std::make_shared<ValueList>(input_names));
  mindspore::ValueNodePtr prim_value_node = std::make_shared<mindspore::ValueNode>(prim_node);
  fg->AddValueNode(prim_value_node);
  mindspore::CNodePtr built_node = fg->NewCNode({prim_value_node, input_node});
  if (built_node == nullptr) {
    MS_LOG(ERROR) << "Build node failed: cnode is nullptr.";
    return nullptr;
  }
  // The output is described with the same shape and dtype as the input.
  ShapeVector input_shape = get_node_shape(input_node);
  TypeId input_type = get_node_dtype(input_node);
  auto built_abstract = std::make_shared<Tensor>(input_type, input_shape)->ToAbstract();
  if (built_abstract == nullptr) {
    MS_LOG(ERROR) << "Build node failed: node abstract is nullptr.";
    return nullptr;
  }
  built_node->set_abstract(built_abstract);
  fg->AddNode(built_node);
  return built_node;
}
630
// Creates a CNode inside fg that applies the first input of node (expected to be a ValueNode) to
// (input_node, weights). The new node takes over the shape and dtype of node as its abstract and is
// added to fg.
// Returns nullptr when an argument is null, node has no inputs, its first input is not a ValueNode, or the
// node / its abstract cannot be created.
mindspore::CNodePtr DynamicObfuscator::BuildOneInputWithWeightNode(const FuncGraphPtr &fg,
                                                                   const mindspore::AnfNodePtr &input_node,
                                                                   const mindspore::CNodePtr &node,
                                                                   const mindspore::AnfNodePtr &weights) const {
  if (node == nullptr) {
    MS_LOG(ERROR) << "Build one input with weight node failed: node is nullptr.";
    return nullptr;
  }
  std::string node_name = node->fullname_with_scope();
  if (input_node == nullptr) {
    MS_LOG(ERROR) << "Build " << node_name << " failed: input node is nullptr.";
    return nullptr;
  }
  if (fg == nullptr) {
    MS_LOG(ERROR) << "Build " << node_name << " failed: FuncGraph is nullptr.";
    return nullptr;
  }
  if (weights == nullptr) {
    MS_LOG(ERROR) << "Build " << node_name << " failed: weights is nullptr.";
    return nullptr;
  }
  const std::vector<AnfNodePtr> &node_inputs = node->inputs();
  if (node_inputs.empty()) {
    MS_LOG(ERROR) << "Build " << node_name << " failed: inputs size is 0";
    return nullptr;
  }
  mindspore::ValueNodePtr v_node = node_inputs[0] == nullptr ? nullptr : node_inputs[0]->cast<mindspore::ValueNodePtr>();
  if (v_node == nullptr) {
    // The cast result was previously passed on unchecked.
    MS_LOG(ERROR) << "Build " << node_name << " failed: the first input is not a ValueNode.";
    return nullptr;
  }
  fg->AddValueNode(v_node);

  mindspore::CNodePtr c_node = fg->NewCNode({v_node, input_node, weights});
  if (c_node == nullptr) {
    MS_LOG(ERROR) << "Build " << node_name << " failed: cnode is nullptr.";
    return nullptr;
  }
  ShapeVector x_shape = get_node_shape(node);
  TypeId type_id = get_node_dtype(node);
  auto node_abstract = std::make_shared<Tensor>(type_id, x_shape)->ToAbstract();
  if (node_abstract == nullptr) {
    MS_LOG(ERROR) << "Build " << node_name << " failed: abstract is nullptr.";
    return nullptr;
  }
  c_node->set_abstract(node_abstract);
  (void)fg->AddNode(c_node);
  return c_node;
}
676
CloneSubGraph(const std::vector<mindspore::CNodePtr> & node_arr,const mindspore::AnfNodePtr & parent_node)677 FuncGraphPtr DynamicObfuscator::CloneSubGraph(const std::vector<mindspore::CNodePtr> &node_arr,
678 const mindspore::AnfNodePtr &parent_node) {
679 MS_LOG(INFO) << "Building Clone Graph ";
680 mindspore::FuncGraphPtr fg_clone = std::make_shared<FuncGraph>();
681 ShapeVector x_shape = get_node_shape(parent_node);
682 TypeId x_type_id = get_node_dtype(parent_node);
683 MS_LOG(INFO) << "Get Shape Input X";
684
685 mindspore::ParameterPtr input_x = fg_clone->add_parameter();
686 if (input_x == nullptr) {
687 MS_LOG(ERROR) << "Build clone graph failed: input_x is nullptr.";
688 return nullptr;
689 }
690 input_x->set_name("input_x_clone");
691 tensor::TensorPtr input_x_tensor = std::make_shared<Tensor>(x_type_id, x_shape);
692 input_x->set_abstract(input_x_tensor->ToAbstract());
693 mindspore::AnfNodePtr last_node = input_x;
694 for (auto node : node_arr) {
695 std::string obf_type = ObfuscateOpType(node);
696 MS_LOG(INFO) << "obf_type: " << obf_type;
697 mindspore::ObfCase obf_case = ObfuscateOpCase(obf_type);
698 switch (obf_case) {
699 case ObfCase::OneInputNoWeightNode: {
700 mindspore::PrimitivePtr prim_node = GetCNodePrimitive(node);
701 last_node = BuildOneInputNoWeightNode(fg_clone, last_node, prim_node);
702 if (last_node == nullptr) {
703 MS_LOG(ERROR) << "Last node after build is nullptr.";
704 return nullptr;
705 }
706 break;
707 }
708 case ObfCase::OneInputWithWeightNode: {
709 mindspore::ParameterPtr weight_param = fg_clone->add_parameter();
710 if (weight_param == nullptr) {
711 MS_LOG(ERROR) << "Build OneInputWithWeightNode failed: weights is nullptr.";
712 return nullptr;
713 }
714 weight_param->set_name("OneInputWithWeightNode_clone");
715 last_node = BuildOneInputWithWeightNode(fg_clone, last_node, node, weight_param);
716 if (last_node == nullptr) {
717 MS_LOG(ERROR) << "Last node after build is nullptr.";
718 return nullptr;
719 }
720 break;
721 }
722 case ObfCase::NotObfNode: {
723 MS_LOG(ERROR) << "The current node does not belong to target nodes.";
724 }
725 default:
726 return nullptr;
727 }
728 }
729
730 mindspore::ValueNodePtr return_v = std::make_shared<mindspore::ValueNode>(mindspore::prim::kPrimReturn);
731 fg_clone->AddValueNode(return_v);
732 mindspore::CNodePtr return_c_node = fg_clone->NewCNode({return_v, last_node});
733 if (return_c_node == nullptr) {
734 MS_LOG(ERROR) << "Build return failed: return cnode is nullptr.";
735 return nullptr;
736 }
737 ShapeVector return_shape = get_node_shape(last_node->cast<mindspore::CNodePtr>());
738 TypeId type_id = get_node_dtype(last_node->cast<mindspore::CNodePtr>());
739 auto return_abstract = std::make_shared<Tensor>(type_id, return_shape)->ToAbstract();
740 if (return_abstract == nullptr) {
741 MS_LOG(ERROR) << "Build return failed: return abstract is nullptr.";
742 return nullptr;
743 }
744 return_c_node->set_abstract(return_abstract);
745 fg_clone->AddNode(return_c_node);
746 fg_clone->set_flag(FUNC_GRAPH_FLAG_DEFER_INLINE, true);
747 fg_clone->set_return(return_c_node);
748 return fg_clone;
749 }
750
BuildFakeGraph(const std::vector<mindspore::CNodePtr> & node_arr,const mindspore::AnfNodePtr & parent_node)751 FuncGraphPtr DynamicObfuscator::BuildFakeGraph(const std::vector<mindspore::CNodePtr> &node_arr,
752 const mindspore::AnfNodePtr &parent_node) {
753 MS_LOG(INFO) << "Building Fake Graph ";
754 mindspore::FuncGraphPtr fg_fake = std::make_shared<FuncGraph>();
755
756 ShapeVector x_shape = get_node_shape(parent_node);
757 TypeId x_type_id = get_node_dtype(parent_node);
758 mindspore::ParameterPtr input_x = fg_fake->add_parameter();
759 if (input_x == nullptr) {
760 MS_LOG(ERROR) << "Build fake graph failed: input_x is nullptr.";
761 return nullptr;
762 }
763 input_x->set_name("input_x_fake");
764 tensor::TensorPtr input_x_tensor = std::make_shared<Tensor>(x_type_id, x_shape);
765 input_x->set_abstract(input_x_tensor->ToAbstract());
766 mindspore::AnfNodePtr last_node = input_x;
767 for (auto node : node_arr) {
768 std::string obf_type = ObfuscateOpType(node);
769 mindspore::ObfCase obf_case = ObfuscateOpCase(obf_type);
770 switch (obf_case) {
771 case ObfCase::OneInputNoWeightNode: {
772 mindspore::PrimitivePtr prim_node = get_random_prim(obf_type, node);
773 last_node = BuildOneInputNoWeightNode(fg_fake, last_node, prim_node);
774 if (last_node == nullptr) {
775 MS_LOG(ERROR) << "Last node after build is nullptr.";
776 return nullptr;
777 }
778 break;
779 }
780 case ObfCase::OneInputWithWeightNode: {
781 mindspore::AnfNodePtr ori_vnode = node->cast<mindspore::CNodePtr>()->inputs()[2];
782 TypeId type_id = get_node_dtype(ori_vnode);
783 ShapeVector shape = get_node_shape(ori_vnode);
784 tensor::TensorPtr weight_tensor = make_weight_tensor(type_id, shape);
785 mindspore::ValueNodePtr weight_vnode = std::make_shared<mindspore::ValueNode>(weight_tensor);
786 if (weight_vnode == nullptr) {
787 MS_LOG(ERROR) << "Build OneInputWithWeightNode failed: value node is nullptr.";
788 return nullptr;
789 }
790 weight_vnode->set_abstract(weight_tensor->ToAbstract());
791 fg_fake->AddValueNode(weight_vnode);
792 last_node = BuildOneInputWithWeightNode(fg_fake, last_node, node, weight_vnode);
793 if (last_node == nullptr) {
794 MS_LOG(ERROR) << "Last node after build is nullptr.";
795 return nullptr;
796 }
797 break;
798 }
799 case ObfCase::NotObfNode: {
800 MS_LOG(ERROR) << "The current node is not obf-target";
801 }
802 default:
803 return nullptr;
804 }
805 }
806
807 mindspore::ValueNodePtr return_v = std::make_shared<mindspore::ValueNode>(mindspore::prim::kPrimReturn);
808 fg_fake->AddValueNode(return_v);
809 mindspore::CNodePtr return_c_node = fg_fake->NewCNode({return_v, last_node});
810 if (return_c_node == nullptr) {
811 MS_LOG(ERROR) << "Build return failed: return cnode is nullptr.";
812 return nullptr;
813 }
814 ShapeVector return_shape = get_node_shape(last_node->cast<mindspore::CNodePtr>());
815 TypeId type_id = get_node_dtype(last_node->cast<mindspore::CNodePtr>());
816 auto return_abstract = std::make_shared<Tensor>(type_id, return_shape)->ToAbstract();
817 if (return_abstract == nullptr) {
818 MS_LOG(ERROR) << "Build return failed: return abstract is nullptr.";
819 return nullptr;
820 }
821 return_c_node->set_abstract(return_abstract);
822 fg_fake->AddNode(return_c_node);
823 fg_fake->set_return(return_c_node);
824 fg_fake->set_flag(FUNC_GRAPH_FLAG_DEFER_INLINE, true);
825 return fg_fake;
826 }
827
AddPartialBranch(const FuncGraphPtr fg,FuncGraphPtr fg_sub,const std::vector<mindspore::CNodePtr> & nodes)828 mindspore::CNodePtr DynamicObfuscator::AddPartialBranch(const FuncGraphPtr fg, FuncGraphPtr fg_sub,
829 const std::vector<mindspore::CNodePtr> &nodes) {
830 if (fg == nullptr) {
831 MS_LOG(ERROR) << "Add subgraph failed: fg is null.";
832 return nullptr;
833 }
834 if (fg_sub == nullptr) {
835 MS_LOG(ERROR) << "Add subgraph failed: fg_sub is null.";
836 return nullptr;
837 }
838 if (nodes.size() == 0) {
839 MS_LOG(ERROR) << "Add subgraph failed: input nodes size is 0.";
840 return nullptr;
841 }
842
843 mindspore::ValueNodePtr switch_partial = std::make_shared<mindspore::ValueNode>(mindspore::prim::kPrimPartial);
844 fg->AddValueNode(switch_partial);
845 mindspore::ValueNodePtr fg_subgraph_node = std::make_shared<mindspore::ValueNode>(fg_sub);
846 fg_subgraph_node->set_abstract(fg_sub->ToAbstract());
847 fg->AddValueNode(fg_subgraph_node);
848 std::vector<mindspore::AnfNodePtr> subgraph_inputs = {switch_partial, fg_subgraph_node};
849 if (nodes[0]->size() < kSwitchInputsNum) {
850 MS_LOG(ERROR) << "Add subgraph failed: the input number of node[0] is smaller than " << kSwitchInputsNum;
851 return nullptr;
852 }
853 subgraph_inputs.push_back(nodes[0]->inputs()[1]);
854 size_t func_params_num = fg_sub->parameters().size();
855 size_t pushed_inputs = 1;
856 for (unsigned i = 0; i < nodes.size(); i++) {
857 if (pushed_inputs >= func_params_num) {
858 break;
859 }
860 std::string obf_type = ObfuscateOpType(nodes[i]);
861 if ((obf_type == kConv2DOpName || obf_type == kMatMulOpName) && nodes[i]->size() >= kNodeWithWeightInputsNum) {
862 subgraph_inputs.push_back(nodes[i]->inputs()[kWeightIndex]);
863 pushed_inputs += 1;
864 }
865 }
866 mindspore::CNodePtr switch_partial_c = fg->NewCNode(subgraph_inputs);
867 if (switch_partial_c == nullptr) {
868 MS_LOG(ERROR) << "Add subgraph failed: switch partial is null.";
869 return nullptr;
870 }
871 switch_partial_c->set_abstract(fg_sub->ToAbstract());
872 fg->AddNode(switch_partial_c);
873 return switch_partial_c;
874 }
875
AddSwitchNode(const FuncGraphPtr fg)876 void DynamicObfuscator::AddSwitchNode(const FuncGraphPtr fg) {
877 if (fg == nullptr) {
878 MS_LOG(ERROR) << "Build switch failed: FuncGraph is nullptr.";
879 return;
880 }
881 int switch_num_ = 0;
882 while (!parent_names_.empty()) {
883 auto mgr = mindspore::Manage(fg);
884 if (mgr == nullptr) {
885 MS_LOG(ERROR) << "FuncGraph manager is nullptr.";
886 return;
887 }
888 std::vector<mindspore::CNodePtr> nodes;
889 mindspore::AnfNodePtr last_node = nullptr;
890 mindspore::CNodePtr child_node = nullptr;
891 while (node_names_.top() != "-") {
892 MS_LOG(INFO) << "Processing sub_graph node: " << node_names_.top();
893 last_node = node_dict_[node_names_.top()];
894 nodes.push_back(last_node->cast<mindspore::CNodePtr>());
895 node_names_.pop(); // pop '-'
896 }
897 node_names_.pop();
898 if (mgr->node_users().find(last_node) != mgr->node_users().cend()) {
899 auto users = mgr->node_users()[last_node];
900 child_node = users.cbegin()->first->cast<mindspore::CNodePtr>();
901 } else {
902 MS_LOG(WARNING) << "Child Node of " << last_node->fullname_with_scope() << " is nullptr.";
903 }
904 mindspore::AnfNodePtr parent_node = node_dict_[parent_names_.top()];
905 parent_names_.pop();
906
907 mindspore::FuncGraphPtr fg_subgraph_clone = CloneSubGraph(nodes, parent_node);
908 mindspore::FuncGraphPtr fg_subgraph_fake = BuildFakeGraph(nodes, parent_node);
909
910 mgr->AddFuncGraph(fg_subgraph_clone);
911 mgr->AddFuncGraph(fg_subgraph_fake);
912
913 mindspore::CNodePtr switch_partial_clone_c = AddPartialBranch(fg, fg_subgraph_clone, nodes);
914 mindspore::CNodePtr switch_partial_fake_c = AddPartialBranch(fg, fg_subgraph_fake, nodes);
915 if (switch_partial_clone_c == nullptr || switch_partial_fake_c == nullptr) {
916 continue;
917 }
918
919 CNodePtr control_node = GetControlNode(fg, parent_node);
920 if (control_node == nullptr) {
921 continue;
922 }
923
924 mindspore::ValueNodePtr switch_v_node = std::make_shared<mindspore::ValueNode>(mindspore::prim::kPrimSwitch);
925 fg->AddValueNode(switch_v_node);
926 mindspore::CNodePtr switch_c_node;
927 if (branch_control_input_ == 0) {
928 if (static_cast<int>(customized_func_results_.size()) <= used_control_node_) {
929 MS_LOG(ERROR) << "customized_func_results_ size is smaller than used_control_node_.";
930 }
931 switch_branch_ = customized_func_results_[used_control_node_];
932 used_control_node_ += 1;
933 }
934 if (switch_branch_) {
935 switch_c_node = fg->NewCNode({switch_v_node, control_node, switch_partial_clone_c, switch_partial_fake_c});
936 } else {
937 switch_c_node = fg->NewCNode({switch_v_node, control_node, switch_partial_fake_c, switch_partial_clone_c});
938 }
939 if (switch_c_node == nullptr) {
940 MS_LOG(ERROR) << "switch_c_node is nullptr.";
941 return;
942 }
943 switch_c_node->set_abstract(fg_subgraph_clone->ToAbstract());
944 fg->AddNode(switch_c_node);
945
946 mindspore::CNodePtr call_cnode = fg->NewCNode({switch_c_node});
947 if (call_cnode == nullptr) {
948 MS_LOG(ERROR) << "call_cnode is nullptr.";
949 return;
950 }
951 fg->AddNode(call_cnode);
952
953 if (child_node != nullptr) {
954 unsigned i = 0;
955 for (auto &weak_input : child_node->weak_inputs()) {
956 auto input = weak_input.lock();
957 MS_EXCEPTION_IF_NULL(input);
958 if (input->fullname_with_scope() == last_node->fullname_with_scope()) {
959 child_node->set_input(i, call_cnode);
960 break;
961 }
962 i++;
963 }
964 switch_num_++;
965 }
966 }
967 MS_LOG(WARNING) << switch_num_ << " switch nodes have been added.";
968 used_control_node_ = 0;
969 }
970
GetNodeMaxNum(const AnfNodeSet nodes)971 int GetNodeMaxNum(const AnfNodeSet nodes) {
972 int node_max_num = 0;
973 for (auto node : nodes) {
974 if (node != nullptr && node->isa<CNode>()) {
975 int op_num = get_op_num(node);
976 if (op_num > node_max_num) {
977 node_max_num = op_num;
978 }
979 }
980 }
981 return node_max_num;
982 }
983
NodePrepareCheck(const mindspore::AnfNodePtr & node,const int & branch_control_input)984 bool NodePrepareCheck(const mindspore::AnfNodePtr &node, const int &branch_control_input) {
985 std::string ignore_name = "down_sample_layer";
986 if (node == nullptr) {
987 MS_LOG(INFO) << "Find null node!" << std::endl;
988 return false;
989 }
990 if (!node->isa<CNode>()) {
991 MS_LOG(INFO) << "Not a Cnode." << std::endl;
992 return false;
993 }
994 // Ignore ResNet's down_sample_layer node for customized func mode.
995 if ((branch_control_input == 0) && (node->fullname_with_scope().find(ignore_name) != std::string::npos)) {
996 MS_LOG(INFO) << "Find down_sample_layer node: " << node->fullname_with_scope() << std::endl;
997 return false;
998 }
999 return true;
1000 }
1001
IsValidOpNum(const int & current_num,const int & compa_num) const1002 bool DynamicObfuscator::IsValidOpNum(const int ¤t_num, const int &compa_num) const {
1003 if (branch_control_input_ != 0) {
1004 return true;
1005 }
1006 return current_num <= compa_num;
1007 }
1008
SubGraphFakeBranch(const FuncGraphPtr func_graph)1009 void DynamicObfuscator::SubGraphFakeBranch(const FuncGraphPtr func_graph) {
1010 if (func_graph == nullptr) {
1011 MS_LOG(ERROR) << "Build fake sub-graph failed: FuncGraph is nullptr.";
1012 return;
1013 }
1014 node_names_.push("-");
1015 auto mgr = mindspore::Manage(func_graph);
1016 if (mgr == nullptr) {
1017 MS_LOG(ERROR) << "Manager is null node!";
1018 return;
1019 }
1020 auto all_nodes = mgr->all_nodes();
1021 int node_nums = SizeToInt(all_nodes.size());
1022 int obfuscate_target_num = std::ceil(node_nums * obf_ratio_ / keyExpandRate);
1023 int op_num = GetNodeMaxNum(all_nodes);
1024 MS_LOG(INFO) << "Init op_num is: " << op_num;
1025 std::vector<mindspore::AnfNodePtr> sorted_nodes;
1026 for (auto node : all_nodes) {
1027 MS_LOG(INFO) << "The last node name is: " << node->fullname_with_scope();
1028 sorted_nodes = TopoSort(node); // the node number in front of sorted nodes is the smallest
1029 break;
1030 }
1031 std::reverse(sorted_nodes.begin(), sorted_nodes.end());
1032 for (auto node : sorted_nodes) {
1033 if (!NodePrepareCheck(node, branch_control_input_)) {
1034 continue;
1035 }
1036 std::string cnode_name = get_node_prim_name(node);
1037 MS_LOG(INFO) << "CNode name is: " << cnode_name;
1038 int cur_op_num = get_op_num(node);
1039 float dropout_rate = 0.1;
1040 int dropout_rand = rand() % static_cast<int>(1.0 / dropout_rate);
1041 if (IsTarget(cnode_name) && IsValidOpNum(cur_op_num, op_num) && dropout_rand != 0 &&
1042 (node_dict_.find(node->fullname_with_scope()) == node_dict_.cend())) {
1043 UpdateDict(node, false);
1044 op_num = cur_op_num;
1045 bool stop_traverse = false;
1046 mindspore::CNodePtr curr_cnode = node->cast<mindspore::CNodePtr>();
1047 while (!stop_traverse) {
1048 mindspore::CNodePtr valid_input = CheckInputNodes(curr_cnode);
1049 dropout_rand = rand() % static_cast<int>(1.0 / dropout_rate);
1050 if (valid_input && dropout_rand != 0 &&
1051 (node_dict_.find(valid_input->fullname_with_scope()) == node_dict_.cend())) {
1052 UpdateDict(valid_input, false);
1053 op_num = get_op_num(valid_input);
1054 curr_cnode = valid_input;
1055 } else {
1056 stop_traverse = true;
1057 if (curr_cnode->size() > 1) {
1058 CheckDuplicatedParent(curr_cnode->inputs()[1]);
1059 }
1060 }
1061 }
1062 }
1063 if (subgraph_obf_num_ >= obfuscate_target_num) {
1064 break;
1065 }
1066 }
1067 node_names_.pop();
1068 if (branch_control_input_ == 0) {
1069 mindspore::kernel::CustomizedOpaquePredicate::GetInstance().init_calling_count();
1070 }
1071 AddSwitchNode(func_graph);
1072 MS_LOG(WARNING) << subgraph_obf_num_ << " nodes have been obfuscated.";
1073 }
1074 } // namespace mindspore
1075