#
# Copyright (c) 2023 Apple Inc. All rights reserved.
# Provided subject to the LICENSE file in the top level directory.
#

"""Small test models, each registered by name in ``MPS_MODEL_NAME_TO_MODEL``.

Every registry entry is a zero-argument callable returning ``(model, inputs)``
where ``inputs`` is a tuple of positional arguments for ``model.forward``.
"""

import random

import executorch.exir.control_flow as control_flow
import torch
from functorch.experimental.control_flow import cond
from torch import nn


class LayerNormModule(torch.nn.Module):
    """Applies ``LayerNorm`` over the trailing ``[5, 10, 10]`` dimensions."""

    def __init__(self) -> None:
        super().__init__()
        self.norm = torch.nn.LayerNorm([5, 10, 10])

    def forward(self, arg):
        return self.norm(arg)

    @staticmethod
    def get_example_inputs():
        # 20 samples whose trailing dims match the normalized shape.
        return (torch.randn(20, 5, 10, 10),)


class Conv2DModule(torch.nn.Module):
    """Single 3x3 ``Conv2d`` mapping 1 input channel to 3 output channels."""

    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(1, 3, 3, stride=1)

    def forward(self, arg):
        return self.conv(arg)

    @staticmethod
    def get_example_inputs():
        # A 3x3 single-channel input; with no padding the 3x3 kernel yields a
        # 1x1 output map per channel.
        return (torch.randn(1, 1, 3, 3),)


class ModuleBasic(torch.nn.Module):
    """Reduces ``sin(x)`` to its single maximum element."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return torch.sin(x).max()

    def get_random_inputs(self):
        return (torch.randn(100),)


class ModuleOpsReturnMulti(torch.nn.Module):
    """Exercises an op with multiple outputs (``topk`` returns values and indices)."""

    def __init__(self):
        super().__init__()

    def forward(self, a, b):
        # Only the top-3 values are used; the indices output is discarded.
        values, _ = torch.topk(a, 3)
        return values * 2 + b

    def get_random_inputs(self):
        # ``b`` has 3 elements to line up with the k=3 values from topk.
        return (torch.randn(10), torch.randn(3))


class ModuleAdd(torch.nn.Module):
    """Element-wise ``x + y``."""

    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        return torch.add(x, y)

    def get_random_inputs(self):
        return (torch.randn(2, 2), torch.randn(2, 2))


class ModuleFloatAddWithAlpha(torch.nn.Module):
    """Computes ``x + c * y`` with a Python float ``alpha``."""

    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor, y: torch.Tensor, c: float):
        return torch.add(x, y, alpha=c)

    def get_random_inputs(self):
        return (torch.randn(2, 2), torch.randn(2, 2), random.random())


class ModuleIntAddWithAlpha(torch.nn.Module):
    """Computes ``x + c * y`` on integer tensors with a Python int ``alpha``."""

    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor, y: torch.Tensor, c: int):
        return torch.add(x, y, alpha=c)

    def get_random_inputs(self):
        return (
            torch.randint(0, 10, (2, 2)),
            torch.randint(0, 10, (2, 2)),
            random.randint(0, 10),
        )


class ModuleContainers(torch.nn.Module):
    """Takes a dict input and returns a dict holding a tuple and a tensor."""

    def __init__(self):
        super().__init__()

    def forward(self, d):
        a = d["a"]
        b = d["b"]
        return {"inputs": (a, b), "c": torch.add(a, b)}

    def get_random_inputs(self):
        return ({"a": torch.randn(2, 2), "b": torch.randn(2, 2)},)


class ToyModelForMemPlanning(torch.nn.Module):
    """Chain of mul/add ops producing several short-lived intermediates."""

    def __init__(self):
        super().__init__()

    def forward(self, a, b):
        o = a
        for _ in range(3):
            o = o * a
            o = o + b
        return o

    def get_random_inputs(self):
        return (
            torch.randn(10),
            torch.randn(10),
        )


class MemPlanningWithScratchTensor(torch.nn.Module):
    """Sum of two independent ``Linear(4, 2)`` projections."""

    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(4, 2)
        self.linear2 = torch.nn.Linear(4, 2)

    def forward(self, a, b):
        o1 = self.linear1(a)
        o2 = self.linear2(b)
        return o1 + o2

    def get_random_inputs(self):
        return (
            torch.randn(10, 4),
            torch.randn(10, 4),
        )


class ModuleOpsReturnTensorList(torch.nn.Module):
    """Calls an aten op that returns a list of tensors and keeps the first."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        split = torch.ops.aten.tensor_split.sections(x, 3)
        return split[0]

    def get_random_inputs(self):
        return (torch.randn(100),)


class ModuleReturnInput(torch.nn.Module):
    """Returns its input unchanged inside tuple, dict and list containers."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return (x, x, {"x": x, "y": x}, [x, x, x])

    def get_random_inputs(self):
        return (torch.randn(1),)


class ModuleIfElse(torch.nn.Module):
    """Squares ``x``, branches on tensor predicate ``c``, then squares the result.

    The true branch computes ``3 * x`` and the false branch ``4 * x``, both via
    repeated addition.
    """

    def __init__(self):
        super().__init__()

    def forward(self, c, x):
        x = x * x

        def addloop(x, n):
            # Sums n copies of x.
            out = x
            for _ in range(n - 1):
                out = out + x
            return out

        def true_branch(c, x):
            return addloop(x, 3)

        def false_branch(c, x):
            return addloop(x, 4)

        y = cond(c, true_branch, false_branch, (c, x))
        return y * y

    def get_random_inputs(self):
        # ``c`` is a one-element bool tensor, randomly True or False.
        return (torch.randint(2, [1]) == 0, torch.randn(10))


class ModuleIfElseWithBoolInput(torch.nn.Module):
    """Same computation as ``ModuleIfElse`` but with a Python ``bool`` predicate."""

    def __init__(self):
        super().__init__()

    def forward(self, c: bool, x: torch.Tensor):
        x = x * x

        def addloop(x, n):
            # Sums n copies of x.
            out = x
            for _ in range(n - 1):
                out = out + x
            return out

        def true_branch(c, x):
            return addloop(x, 3)

        def false_branch(c, x):
            return addloop(x, 4)

        y = cond(c, true_branch, false_branch, (c, x))

        return y * y

    def get_random_inputs(self):
        return (random.randint(0, 1) == 0, torch.randn(10))


class ModuleWhileIf(torch.nn.Module):
    """A ``cond`` nested inside a ``control_flow.while_loop``.

    Each iteration adds the result of a ``cond`` to ``accum`` and decrements
    ``cnt``. The cond has a constant ``True`` predicate and its true branch
    returns ``cnt`` unchanged. The loop runs while ``cnt`` is not zero.
    """

    def __init__(self):
        super().__init__()

    def forward(self, accum, cnt):
        # NOTE(review): tracing_context presumably provides example inputs used
        # to trace each nested function; confirm against exir.control_flow.
        @control_flow.tracing_context(
            inputs=(torch.zeros([1]).to(dtype=torch.long), torch.randint(10, 100, [1]))
        )
        def loop_cond(accum, cnt):
            return cnt != torch.zeros([1]).to(dtype=torch.long)

        @control_flow.tracing_context(
            inputs=(torch.zeros([1]).to(dtype=torch.long), torch.randint(10, 100, [1]))
        )
        def loop_body(accum, cnt):
            @control_flow.tracing_context(
                inputs=(torch.zeros([1]).to(dtype=torch.long),)
            )
            def true_branch(cnt):
                return cnt

            @control_flow.tracing_context(
                inputs=(torch.zeros([1]).to(dtype=torch.long),)
            )
            def false_branch(cnt):
                return torch.zeros([1], dtype=torch.long)

            # Effectively ``accum + cnt``, routed through a cond so that the
            # loop body contains nested control flow.
            accum = accum + cond(
                torch.BoolTensor([True]), true_branch, false_branch, (cnt,)
            )
            # 'cnt - 1' does not work yet since the runtime does not expect
            # tensor to be mixed with scalar for sub op.
            return accum, cnt - torch.ones([1]).to(dtype=torch.long)

        y, _ = control_flow.while_loop(
            loop_cond,
            loop_body,
            (accum, cnt),
        )
        return y

    def get_random_inputs(self):
        return (torch.zeros([1]).to(dtype=torch.long), torch.randint(10, 100, [1]))


class ModuleIfWhile(torch.nn.Module):
    """A ``control_flow.while_loop`` nested inside the true branch of a ``cond``.

    The predicate is the constant ``True``. The loop repeatedly adds ``cnt`` to
    ``accum`` and decrements ``cnt`` while ``cnt`` is not zero. Element 0 of the
    cond result (the accumulator) is returned.
    """

    def __init__(self):
        super().__init__()

    def forward(self, accum, cnt):
        @control_flow.tracing_context(
            inputs=(torch.zeros([1]).to(dtype=torch.long), torch.randint(10, 100, [1]))
        )
        def true_branch(accum, cnt):
            @control_flow.tracing_context(
                inputs=(
                    torch.zeros([1]).to(dtype=torch.long),
                    torch.randint(10, 100, [1]),
                )
            )
            def loop_cond(accum, cnt):
                return cnt != torch.zeros([1]).to(dtype=torch.long)

            @control_flow.tracing_context(
                inputs=(
                    torch.zeros([1]).to(dtype=torch.long),
                    torch.randint(10, 100, [1]),
                )
            )
            def loop_body(accum, cnt):
                return accum + cnt, cnt - torch.ones([1]).to(dtype=torch.long)

            return control_flow.while_loop(loop_cond, loop_body, (accum, cnt))

        @control_flow.tracing_context(
            inputs=(torch.zeros([1]).to(dtype=torch.long), torch.randint(10, 100, [1]))
        )
        def false_branch(accum, cnt):
            return accum, cnt

        return cond(torch.BoolTensor([True]), true_branch, false_branch, (accum, cnt))[
            0
        ]

    def get_random_inputs(self):
        return (torch.zeros([1]).to(dtype=torch.long), torch.randint(10, 100, [1]))


class ModuleContiguousTensor(torch.nn.Module):
    """Single ``Linear(8, 32)`` layer."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 32)

    def forward(self, arg):
        return self.linear(arg)

    def get_random_inputs(self):
        return (torch.randn(3, 8),)


class ModuleInputDynamicShape(torch.nn.Module):
    """Element-wise ops on a 1-D input whose length varies between calls."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        for _ in range(4):
            x = x + x
            x = x * x
        return x

    def get_upper_bound_inputs(self):
        # Largest input length that get_random_inputs can produce.
        return (torch.randn(10),)

    def get_random_inputs(self):
        # Length is drawn uniformly from 1..10 inclusive.
        n = random.randint(1, 10)
        return (torch.randn(n),)


class ModuleIntermediateDynamicShape(torch.nn.Module):
    """Produces a data-dependent intermediate shape via ``torch.nonzero``."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        x = x * x

        # We should use x[torch.nonzero(x)] ideally, but index op is not supported
        # in the runtime so far.
        x = torch.nonzero(x)
        return x + x

    def get_random_inputs(self):
        # Float tensor of 0.0/1.0 values, so the nonzero count varies per call.
        return (torch.randint(0, 2, (10,), dtype=torch.float),)


def _with_random_inputs(module_cls):
    """Instantiate ``module_cls`` once and pair it with its own random inputs.

    Uses a single instance for both the model and the ``get_random_inputs()``
    call. This avoids constructing a throwaway second instance, which for
    modules with ``nn.Linear`` layers would also run a discarded random weight
    initialization.

    Returns:
        ``(model, inputs)`` where ``inputs`` is the tuple returned by
        ``model.get_random_inputs()``.
    """
    model = module_cls()
    return model, model.get_random_inputs()


# Maps a model name to a zero-argument factory returning ``(model, inputs)``.
MPS_MODEL_NAME_TO_MODEL = {
    "conv2D": lambda: (Conv2DModule(), Conv2DModule.get_example_inputs()),
    "norm": lambda: (LayerNormModule(), LayerNormModule.get_example_inputs()),
    "module_basic": lambda: _with_random_inputs(ModuleBasic),
    "module_ops_return_multi": lambda: _with_random_inputs(ModuleOpsReturnMulti),
    "module_add": lambda: _with_random_inputs(ModuleAdd),
    "module_float_add_with_alpha": lambda: _with_random_inputs(
        ModuleFloatAddWithAlpha
    ),
    "module_int_add_with_alpha": lambda: _with_random_inputs(ModuleIntAddWithAlpha),
    "module_containers": lambda: _with_random_inputs(ModuleContainers),
    "toy_model_for_mem_planning": lambda: _with_random_inputs(
        ToyModelForMemPlanning
    ),
    "mem_planning_with_scratch_tensor": lambda: _with_random_inputs(
        MemPlanningWithScratchTensor
    ),
    "module_ops_return_tensor_list": lambda: _with_random_inputs(
        ModuleOpsReturnTensorList
    ),
    "module_return_input": lambda: _with_random_inputs(ModuleReturnInput),
    "module_if_else": lambda: _with_random_inputs(ModuleIfElse),
    "module_if_else_with_bool_input": lambda: _with_random_inputs(
        ModuleIfElseWithBoolInput
    ),
    "module_while_if": lambda: _with_random_inputs(ModuleWhileIf),
    "module_if_while": lambda: _with_random_inputs(ModuleIfWhile),
    "module_contiguous_tensor": lambda: _with_random_inputs(ModuleContiguousTensor),
    "module_input_dynamic_shape": lambda: _with_random_inputs(
        ModuleInputDynamicShape
    ),
    "module_intermediate_dynamic_shape": lambda: _with_random_inputs(
        ModuleIntermediateDynamicShape
    ),
}