1py::object ${func_name}_Base(const PrimitivePtr &prim, const py::list &args) { 2 #ifndef ENABLE_TEST 3 MS_LOG(DEBUG) << "Run ${func_name} start"; 4 auto op_run_info = PyNativeAlgo::PyBoost::Init(prim, args); 5 op_run_info->signatures = ops::${op_def_name}.signatures_; 6 static Converter converter(&ops::${op_def_name}); 7 converter.Parse(args); 8 ${parser_body} 9 10 static auto top_type = PredictOutType(op_run_info); 11 auto node = stub::MakeTopNode(top_type); 12 GilReleaseWithCheck release_gil; 13 op_run_info->stub_output = node.second; 14 op_run_info->source_type = converter.source_type(); 15 DispatchOp( 16 std::make_shared<FrontendTask>( 17 [${op_args}](const FrontendOpRunInfoPtr &op_run_info) { 18 MS_LOG(DEBUG) << "Run frontend task ${func_name} start"; 19 auto old_stream_id = kernel::pyboost::PyBoostUtils::cur_stream_id(); 20 kernel::pyboost::PyBoostUtils::set_cur_stream_id(op_run_info->base_op_run_info.stream_id); 21 22 // stub tensor to tensor. 23 ${convert_stub} 24 25 // Create op 26 auto op = CREATE_PYBOOST_OP(${op_name}, op_run_info->base_op_run_info.device_target); 27 28 // Do mixed precision and implicit cast 29 static const std::vector<std::vector<size_t>> same_type_table{${same_type}}; 30 auto [${cast_args}] = PyNativeAlgo::PyBoost::SetPyBoostCastForInputs<${type_num}>(op_run_info, same_type_table, ${call_args}); 31 32 // Run op 33 (void)op->Call(${cast_args}); 34 ${optional_to_value} 35 36 // Data sync in mix mode(Graph and PyNative) 37 PyNativeAlgo::PyBoost::DataSyncForGraph(op, {${grad_args}}); 38 39 // Update op and op_run_info by op outputs 40 PyNativeAlgo::PyBoost::UpdateOpRunInfo(op, op_run_info); 41 42 // Do auto grad 43 if (op_run_info->requires_grad) { 44 PyNativeAlgo::PyBoost::DoGrad(op, op_run_info, {${grad_args}}); 45 } 46 kernel::pyboost::PyBoostUtils::set_cur_stream_id(old_stream_id); 47 48 MS_LOG(DEBUG) << "Run frontend task ${func_name} end"; 49 }, 50 op_run_info 51 ) 52 ); 53 MS_LOG(DEBUG) << "Run ${func_name} end"; 54 return node.first; 55 #else 56 return PyNativeAlgo::PyBoost::RunPyFunction(prim, args); 57 #endif 58} 59 60py::object ${func_name}(const py::args &args) { 61 if (args.size() != kIndex2) { 62 MS_LOG(EXCEPTION) << "Two args are needed by RunOp" 63 << ", but got " << args.size(); 64 } 65 const auto &prim = PyNativeAlgo::PyBoost::ConvertPrimitive(args[0]); 66 runtime::ProfilerRecorder profiler(runtime::ProfilerModule::kPynative, runtime::ProfilerEvent::kRunOp, 67 prim->name(), false, true); 68 return ${func_name}_Base(prim, args[1]); 69} 70 71class ${class_name}PrimAdapter: public PrimitiveFunctionAdapter { 72 public: 73 ${class_name}PrimAdapter() : PrimitiveFunctionAdapter() {} 74 ~${class_name}PrimAdapter() = default; 75 std::string name() override { return "${class_name}"; } 76 py::object Call(const py::args &args) { 77 runtime::ProfilerRecorder profiler(runtime::ProfilerModule::kPynative, runtime::ProfilerEvent::kRunOp, 78 "${class_name}", false, true); 79 return ${func_name}_Base(prim::kPrim${class_name}, args); 80 } 81}; 82