• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1"""Error reproduction utilities for op consistency tests."""
2
3from __future__ import annotations
4
5import difflib
6import pathlib
7import platform
8import sys
9import time
10import traceback
11
12import numpy as np
13
14import onnx
15import onnxruntime as ort
16import onnxscript
17
18import torch
19
20
# Markdown body of a mismatch reproduction report. It is rendered with
# ``str.format`` by ``create_mismatch_report``, which supplies every field:
#   test_name, short_test_name, sample_num -- identify the failing test sample
#                                            (short_test_name feeds ``pytest -k``)
#   onnx_model_text                        -- text form of the executed ONNX model
#   input_shapes, kwargs, inputs           -- the sample's inputs
#   expected_shape, expected               -- expected (PyTorch) output
#   actual_shape, actual                   -- actual (ONNX Runtime) output
#   diff                                   -- unified diff of actual vs. expected
#   error_stack                            -- error message plus traceback
#   sys_info                               -- OS / Python / package versions
# Because of ``str.format``, any literal brace added to this template must be
# doubled (``{{`` / ``}}``).
_MISMATCH_MARKDOWN_TEMPLATE = """\
### Summary

The output of ONNX Runtime does not match that of PyTorch when executing test
`{test_name}`, `sample {sample_num}` in ONNX Script `TorchLib`.

To recreate this report, use

```bash
CREATE_REPRODUCTION_REPORT=1 python -m pytest onnxscript/tests/function_libs/torch_lib/ops_test.py -k {short_test_name}
```

### ONNX Model

```
{onnx_model_text}
```

### Inputs

Shapes: `{input_shapes}`

<details><summary>Details</summary>
<p>

```python
kwargs = {kwargs}
inputs = {inputs}
```

</p>
</details>

### Expected output

Shape: `{expected_shape}`

<details><summary>Details</summary>
<p>

```python
expected = {expected}
```

</p>
</details>

### Actual output

Shape: `{actual_shape}`

<details><summary>Details</summary>
<p>

```python
actual = {actual}
```

</p>
</details>

### Difference

<details><summary>Details</summary>
<p>

```diff
{diff}
```

</p>
</details>

### Full error stack

```
{error_stack}
```

### Environment

```
{sys_info}
```

"""
107
108
def _format_input_shapes(inputs) -> str:
    """Summarize *inputs* for the report's "Shapes" line.

    Torch tensors are rendered as ``Tensor<shape, dtype=...>``; any other
    input is kept unchanged. Returns the ``repr`` of the resulting list.
    """
    summary = [
        f"Tensor<{inp.shape}, dtype={inp.dtype}>" if isinstance(inp, torch.Tensor) else inp
        for inp in inputs
    ]
    return repr(summary)


def _format_sys_info() -> str:
    """Return the OS, Python and package versions recorded in a report."""
    return "\n".join(
        [
            f"OS: {platform.platform()}",
            f"Python version: {sys.version}",
            f"onnx=={onnx.__version__}",
            f"onnxruntime=={ort.__version__}",
            f"onnxscript=={onnxscript.__version__}",
            f"numpy=={np.__version__}",
            f"torch=={torch.__version__}",
        ]
    )


def _sanitize_file_name(name: str) -> str:
    """Replace characters that are not allowed in file names with ``-``.

    ``/`` is the POSIX path separator; ``<>:"\\|?*`` are additionally
    rejected by Windows.
    """
    return name.translate(str.maketrans({char: "-" for char in '<>:"/\\|?*'}))


def create_mismatch_report(
    test_name: str,
    sample_num: int,
    onnx_model: onnx.ModelProto,
    inputs,
    kwargs,
    actual,
    expected,
    error: Exception,
) -> None:
    """Write a Markdown report describing an output mismatch for one test sample.

    The report is rendered from ``_MISMATCH_MARKDOWN_TEMPLATE``, saved with
    ``save_error_report``, and its path is printed to stdout.

    Args:
        test_name: Test name; the part after the last ``.`` is used for the
            ``pytest -k`` command in the report and for the file name.
        sample_num: Index of the failing sample.
        onnx_model: The model that was executed; embedded in text form.
        inputs: Positional inputs of the sample. It is iterated, and torch
            tensors are summarized by shape and dtype on the "Shapes" line.
        kwargs: Keyword arguments of the sample.
        actual: Output under test (presumably from ONNX Runtime, per the
            report text).
        expected: Reference output (presumably from PyTorch).
        error: Exception raised by the comparison. Its message and traceback
            are embedded in the report.
    """
    short_test_name = test_name.split(".")[-1]
    error_stack = str(error) + "\n" + "".join(traceback.format_tb(error.__traceback__))
    onnx_model_text = onnx.printer.to_text(onnx_model)

    # Tensors must be printed in full, otherwise the report would show "..."
    # instead of the values needed to reproduce the mismatch. Restore the
    # previous threshold afterwards so that tensor printing elsewhere in the
    # process (e.g. later assertion messages) is not affected.
    previous_threshold = torch._tensor_str.PRINT_OPTS.threshold  # pylint: disable=protected-access
    torch.set_printoptions(threshold=sys.maxsize)
    try:
        diff = difflib.unified_diff(
            str(actual).splitlines(),
            str(expected).splitlines(),
            fromfile="actual",
            tofile="expected",
            lineterm="",
        )
        markdown = _MISMATCH_MARKDOWN_TEMPLATE.format(
            test_name=test_name,
            short_test_name=short_test_name,
            sample_num=sample_num,
            input_shapes=_format_input_shapes(inputs),
            inputs=inputs,
            kwargs=kwargs,
            expected=expected,
            expected_shape=expected.shape if isinstance(expected, torch.Tensor) else None,
            actual=actual,
            actual_shape=actual.shape if isinstance(actual, torch.Tensor) else None,
            diff="\n".join(diff),
            error_stack=error_stack,
            sys_info=_format_sys_info(),
            onnx_model_text=onnx_model_text,
        )
    finally:
        torch.set_printoptions(threshold=previous_threshold)

    # The timestamp keeps repeated reports for the same test from overwriting
    # each other.
    timestamp = str(time.time()).replace(".", "_")
    markdown_file_name = f"mismatch-{_sanitize_file_name(short_test_name)}-{timestamp}.md"
    markdown_file_path = save_error_report(markdown_file_name, markdown)
    print(f"Created reproduction report at {markdown_file_path}")
169
170
def save_error_report(
    file_name: str,
    text: str,
    *,
    reports_dir: str | pathlib.Path = "error_reports",
) -> pathlib.Path:
    """Write *text* to ``reports_dir / file_name`` and return that path.

    Args:
        file_name: Name of the report file inside *reports_dir*.
        text: Report contents. Written as UTF-8, replacing any existing file
            with the same name.
        reports_dir: Directory that receives the report. It is created, along
            with any missing parents, if it does not exist. Defaults to
            ``error_reports`` relative to the current working directory.

    Returns:
        Path of the written file (relative when *reports_dir* is relative).
    """
    directory = pathlib.Path(reports_dir)
    directory.mkdir(parents=True, exist_ok=True)
    file_path = directory / file_name
    file_path.write_text(text, encoding="utf-8")
    return file_path
179