1"""Error reproduction utilities for op consistency tests.""" 2 3from __future__ import annotations 4 5import difflib 6import pathlib 7import platform 8import sys 9import time 10import traceback 11 12import numpy as np 13 14import onnx 15import onnxruntime as ort 16import onnxscript 17 18import torch 19 20 21_MISMATCH_MARKDOWN_TEMPLATE = """\ 22### Summary 23 24The output of ONNX Runtime does not match that of PyTorch when executing test 25`{test_name}`, `sample {sample_num}` in ONNX Script `TorchLib`. 26 27To recreate this report, use 28 29```bash 30CREATE_REPRODUCTION_REPORT=1 python -m pytest onnxscript/tests/function_libs/torch_lib/ops_test.py -k {short_test_name} 31``` 32 33### ONNX Model 34 35``` 36{onnx_model_text} 37``` 38 39### Inputs 40 41Shapes: `{input_shapes}` 42 43<details><summary>Details</summary> 44<p> 45 46```python 47kwargs = {kwargs} 48inputs = {inputs} 49``` 50 51</p> 52</details> 53 54### Expected output 55 56Shape: `{expected_shape}` 57 58<details><summary>Details</summary> 59<p> 60 61```python 62expected = {expected} 63``` 64 65</p> 66</details> 67 68### Actual output 69 70Shape: `{actual_shape}` 71 72<details><summary>Details</summary> 73<p> 74 75```python 76actual = {actual} 77``` 78 79</p> 80</details> 81 82### Difference 83 84<details><summary>Details</summary> 85<p> 86 87```diff 88{diff} 89``` 90 91</p> 92</details> 93 94### Full error stack 95 96``` 97{error_stack} 98``` 99 100### Environment 101 102``` 103{sys_info} 104``` 105 106""" 107 108 109def create_mismatch_report( 110 test_name: str, 111 sample_num: int, 112 onnx_model: onnx.ModelProto, 113 inputs, 114 kwargs, 115 actual, 116 expected, 117 error: Exception, 118) -> None: 119 torch.set_printoptions(threshold=sys.maxsize) 120 121 error_text = str(error) 122 error_stack = error_text + "\n" + "".join(traceback.format_tb(error.__traceback__)) 123 short_test_name = test_name.split(".")[-1] 124 diff = difflib.unified_diff( 125 str(actual).splitlines(), 126 str(expected).splitlines(), 127 fromfile="actual", 128 tofile="expected", 129 lineterm="", 130 ) 131 onnx_model_text = onnx.printer.to_text(onnx_model) 132 input_shapes = repr( 133 [ 134 f"Tensor<{inp.shape}, dtype={inp.dtype}>" 135 if isinstance(inp, torch.Tensor) 136 else inp 137 for inp in inputs 138 ] 139 ) 140 sys_info = f"""\ 141OS: {platform.platform()} 142Python version: {sys.version} 143onnx=={onnx.__version__} 144onnxruntime=={ort.__version__} 145onnxscript=={onnxscript.__version__} 146numpy=={np.__version__} 147torch=={torch.__version__}""" 148 149 markdown = _MISMATCH_MARKDOWN_TEMPLATE.format( 150 test_name=test_name, 151 short_test_name=short_test_name, 152 sample_num=sample_num, 153 input_shapes=input_shapes, 154 inputs=inputs, 155 kwargs=kwargs, 156 expected=expected, 157 expected_shape=expected.shape if isinstance(expected, torch.Tensor) else None, 158 actual=actual, 159 actual_shape=actual.shape if isinstance(actual, torch.Tensor) else None, 160 diff="\n".join(diff), 161 error_stack=error_stack, 162 sys_info=sys_info, 163 onnx_model_text=onnx_model_text, 164 ) 165 166 markdown_file_name = f'mismatch-{short_test_name.replace("/", "-").replace(":", "-")}-{str(time.time()).replace(".", "_")}.md' 167 markdown_file_path = save_error_report(markdown_file_name, markdown) 168 print(f"Created reproduction report at {markdown_file_path}") 169 170 171def save_error_report(file_name: str, text: str): 172 reports_dir = pathlib.Path("error_reports") 173 reports_dir.mkdir(parents=True, exist_ok=True) 174 file_path = reports_dir / file_name 175 with open(file_path, "w", encoding="utf-8") as f: 176 f.write(text) 177 178 return file_path 179