• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1mindspore.train.RunContext
2==========================
3
4.. py:class:: mindspore.train.RunContext(original_args)
5
6    保存和管理模型的相关信息。
7
8    `RunContext` 主要用于收集训练或推理过程中模型的上下文相关信息并作为入参传入callback对象中来实现信息的共享。
9
10    Callback的类方法中,调用 `RunContext.original_args()` 可以获取模型当前的上下文信息,用户也可以为此信息添加额外的自定义属性,同时 `request_stop()` 方法可以控制训练过程的停止。具体用法请查看 `回调机制Callback <https://www.mindspore.cn/tutorials/zh-CN/master/advanced/model/callback.html>`_。
11
12    `RunContext.original_args()` 存储的模型信息为一个字典型变量,在训练和推理过程会存储不同的属性。详情如下:
13
14    +--------------------------+------------------------+---------------------------------------+
15    |   训练过程支持的属性     |   推理过程支持的属性   |               说明                    |
16    +==========================+========================+=======================================+
17    |   train_network          |                        |    包含了优化器和损失的训练网络       |
18    +--------------------------+------------------------+---------------------------------------+
19    |   epoch_num              |                        |      训练的epoch数                    |
20    +--------------------------+------------------------+---------------------------------------+
21    |  train_dataset           |                        |         训练集                        |
22    +--------------------------+------------------------+---------------------------------------+
23    |   loss_fn                |                        |         损失函数                      |
24    +--------------------------+------------------------+---------------------------------------+
25    |   optimizer              |                        |         优化器                        |
26    +--------------------------+------------------------+---------------------------------------+
27    |  parallel_mode           |                        |         并行模式                      |
28    +--------------------------+------------------------+---------------------------------------+
29    |   device_number          |                        |         设备编号                      |
30    +--------------------------+------------------------+---------------------------------------+
31    |   train_dataset_element  |                        |         当前step的训练数据            |
32    +--------------------------+------------------------+---------------------------------------+
33    |  last_save_ckpt_step     |                        |      最后一次存储ckpt的step           |
34    +--------------------------+------------------------+---------------------------------------+
35    |  latest_ckpt_file        |                        |            ckpt文件名                 |
36    +--------------------------+------------------------+---------------------------------------+
37    |   cur_epoch_num          |                        |          当前的epoch                  |
38    +--------------------------+------------------------+---------------------------------------+
39    |                          |  eval_network          |          评估网络                     |
40    +--------------------------+------------------------+---------------------------------------+
41    |                          |  valid_dataset         |          验证集                       |
42    +--------------------------+------------------------+---------------------------------------+
43    |                          |   metrics              |          评估指标                     |
44    +--------------------------+------------------------+---------------------------------------+
45    |   mode                   |   mode                 |        "train"或"eval"模式            |
46    +--------------------------+------------------------+---------------------------------------+
47    |  batch_num               |   batch_num            |        训练或推理的batch数            |
48    +--------------------------+------------------------+---------------------------------------+
49    |   list_callback          |   list_callback        |        回调列表                       |
50    +--------------------------+------------------------+---------------------------------------+
51    |   network                |    network             |       基础的网络结构                  |
52    +--------------------------+------------------------+---------------------------------------+
53    |  cur_step_num            |    cur_step_num        |       当前的训练或推理的step          |
54    +--------------------------+------------------------+---------------------------------------+
55    |   dataset_sink_mode      |    dataset_sink_mode   |       训练或推理的数据是否下沉        |
56    +--------------------------+------------------------+---------------------------------------+
57    |   net_outputs            |      net_outputs       |       训练或推理的网络输出            |
58    +--------------------------+------------------------+---------------------------------------+
59
60    参数:
61        - **original_args** (dict) - 模型的相关信息。
62
63    .. py:method:: get_stop_requested()
64
65        获取是否停止训练的标志。
66
67        返回:
68            bool,如果为True,则 `Model.train()` 停止迭代。
69
70    .. py:method:: original_args()
71
72        获取模型相关信息的对象。
73
74        返回:
75            dict,含有模型的相关信息的对象。
76
77        教程样例:
78            - `回调机制 Callback - 自定义回调机制
79              <https://mindspore.cn/tutorials/zh-CN/master/advanced/model/callback.html#自定义回调机制>`_
80
81    .. py:method:: request_stop()
82
83        在训练期间设置停止请求。
84
85        可以使用此函数请求停止训练。 `Model.train()` 会检查是否调用此函数。
86
87        教程样例:
88            - `回调机制 Callback - 自定义终止训练
89              <https://mindspore.cn/tutorials/zh-CN/master/advanced/model/callback.html#自定义终止训练>`_
90