// (Code-browser navigation residue, not part of the source: Home, Line#, Scopes#, Navigate#, Raw, Download.)
// Core implementation of ${func_name}: takes the primitive `prim` and its
// Python argument list `args`, and returns a py::object.
//
// Every `${...}` token is a template placeholder; presumably a code generator
// substitutes them before this text is compiled.
//
// When ENABLE_TEST is not defined, the function:
//   1. builds the op run info from `prim`/`args` and parses `args` with a
//      per-op Converter,
//   2. creates a stub output node and stores its second element in
//      op_run_info->stub_output,
//   3. hands a FrontendTask to DispatchOp; the task body creates the pyboost
//      op, casts the inputs, calls the op and optionally runs auto grad,
//   4. returns the first element of the stub node pair.
// When ENABLE_TEST is defined, it forwards to
// PyNativeAlgo::PyBoost::RunPyFunction(prim, args) instead.
py::object ${func_name}_Base(const PrimitivePtr &prim, const py::list &args) {
  #ifndef ENABLE_TEST
    MS_LOG(DEBUG) << "Run ${func_name} start";
    auto op_run_info = PyNativeAlgo::PyBoost::Init(prim, args);
    op_run_info->signatures = ops::${op_def_name}.signatures_;
    // Function-local static: the Converter is constructed once and reused by
    // every later call of this function.
    static Converter converter(&ops::${op_def_name});
    converter.Parse(args);
    ${parser_body}

    // Function-local static: PredictOutType runs only on the first call, with
    // that call's op_run_info; later calls reuse the stored value.
    // NOTE(review): this assumes the predicted output type never differs
    // between calls of the same op -- confirm.
    static auto top_type = PredictOutType(op_run_info);
    auto node = stub::MakeTopNode(top_type);
    // Presumably a scope guard that releases the Python GIL until the end of
    // this function -- confirm against GilReleaseWithCheck.
    GilReleaseWithCheck release_gil;
    op_run_info->stub_output = node.second;
    op_run_info->source_type = converter.source_type();
    DispatchOp(
      std::make_shared<FrontendTask>(
        [${op_args}](const FrontendOpRunInfoPtr &op_run_info) {
          MS_LOG(DEBUG) << "Run frontend task ${func_name} start";
          // Remember the current stream id and switch to the one recorded for
          // this op; it is switched back at the end of the task.
          // NOTE(review): the restore below is a plain statement, so it is
          // skipped if anything in between throws -- confirm this is intended.
          auto old_stream_id = kernel::pyboost::PyBoostUtils::cur_stream_id();
          kernel::pyboost::PyBoostUtils::set_cur_stream_id(op_run_info->base_op_run_info.stream_id);

          // stub tensor to tensor.
          ${convert_stub}

          // Create op
          auto op = CREATE_PYBOOST_OP(${op_name}, op_run_info->base_op_run_info.device_target);

          // Do mixed precision and implicit cast
          static const std::vector<std::vector<size_t>> same_type_table{${same_type}};
          auto [${cast_args}] = PyNativeAlgo::PyBoost::SetPyBoostCastForInputs<${type_num}>(op_run_info, same_type_table, ${call_args});

          // Run op; the value returned by Call is deliberately discarded.
          (void)op->Call(${cast_args});
          ${optional_to_value}

          // Data sync in mix mode(Graph and PyNative)
          PyNativeAlgo::PyBoost::DataSyncForGraph(op, {${grad_args}});

          // Update op and op_run_info by op outputs
          PyNativeAlgo::PyBoost::UpdateOpRunInfo(op, op_run_info);

          // Do auto grad
          if (op_run_info->requires_grad) {
            PyNativeAlgo::PyBoost::DoGrad(op, op_run_info, {${grad_args}});
          }
          // Switch back to the stream id that was current when the task began.
          kernel::pyboost::PyBoostUtils::set_cur_stream_id(old_stream_id);

          MS_LOG(DEBUG) << "Run frontend task ${func_name} end";
        },
        op_run_info
      )
    );
    MS_LOG(DEBUG) << "Run ${func_name} end";
    return node.first;
  #else
    return PyNativeAlgo::PyBoost::RunPyFunction(prim, args);
  #endif
}

60py::object ${func_name}(const py::args &args) {
61  if (args.size() != kIndex2) {
62    MS_LOG(EXCEPTION) << "Two args are needed by RunOp"
63                      << ", but got " << args.size();
64  }
65  const auto &prim = PyNativeAlgo::PyBoost::ConvertPrimitive(args[0]);
66  runtime::ProfilerRecorder profiler(runtime::ProfilerModule::kPynative, runtime::ProfilerEvent::kRunOp,
67                                     prim->name(), false, true);
68  return ${func_name}_Base(prim, args[1]);
69}
70
71class ${class_name}PrimAdapter: public PrimitiveFunctionAdapter {
72  public:
73   ${class_name}PrimAdapter() : PrimitiveFunctionAdapter() {}
74   ~${class_name}PrimAdapter() = default;
75   std::string name() override { return "${class_name}"; }
76   py::object Call(const py::args &args) {
77     runtime::ProfilerRecorder profiler(runtime::ProfilerModule::kPynative, runtime::ProfilerEvent::kRunOp,
78                                        "${class_name}", false, true);
79     return ${func_name}_Base(prim::kPrim${class_name}, args);
80   }
81};
82