Lines Matching full:torch
4 import torch
5 from torch import Tensor
6 from torch.ao.quantization.fake_quantize import (
10 from torch.ao.quantization.observer import (
16 from torch.ao.quantization.quantizer import DerivedQuantizationSpec, QuantizationSpec
17 from torch.fx import Node
39 (broadcast_act_scale, broadcast_weight_scale) = torch.broadcast_tensors(
42 derived_scale = (broadcast_act_scale * broadcast_weight_scale).to(torch.float32)
43 derived_zero = torch.zeros(derived_scale.size()).to(torch.int32)
54 dtype=torch.int32,
55 quant_min=torch.iinfo(torch.int32).min,
56 quant_max=torch.iinfo(torch.int32).max,
58 qscheme=torch.per_channel_symmetric,
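Read together, the matches at lines 39-58 follow the standard pattern for deriving a bias quantization spec from the activation and weight observers: bias scale = activation scale x weight scale, zero point fixed at 0, stored as per-channel symmetric int32. A minimal sketch of that pattern, assuming the usual DerivedQuantizationSpec wiring (the helper names and the edge list here are illustrative, not confirmed by these matches):

import torch
from torch import Tensor
from torch.ao.quantization.quantizer import DerivedQuantizationSpec
from torch.fx import Node


def _derive_bias_qparams_fn(obs_or_fqs) -> tuple[Tensor, Tensor]:
    # the anchor op's activation and weight observers, in derived_from order
    act_obs, weight_obs = obs_or_fqs
    act_scale, _ = act_obs.calculate_qparams()
    weight_scale, _ = weight_obs.calculate_qparams()
    # broadcast so a per-tensor act scale combines with a per-channel weight scale
    (broadcast_act_scale, broadcast_weight_scale) = torch.broadcast_tensors(
        act_scale, weight_scale
    )
    derived_scale = (broadcast_act_scale * broadcast_weight_scale).to(torch.float32)
    # symmetric int32 bias: the zero point is always 0
    derived_zero = torch.zeros(derived_scale.size()).to(torch.int32)
    return derived_scale, derived_zero


def make_bias_spec(node: Node, input_act: Node, weight: Node) -> DerivedQuantizationSpec:
    return DerivedQuantizationSpec(
        # (producer, consumer) edges; the exact wiring is an assumption
        derived_from=[(input_act, node), (weight, node)],
        derive_qparams_fn=_derive_bias_qparams_fn,
        dtype=torch.int32,
        quant_min=torch.iinfo(torch.int32).min,
        quant_max=torch.iinfo(torch.int32).max,
        ch_axis=0,
        qscheme=torch.per_channel_symmetric,
    )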
68 dtype=torch.uint8,
70 torch.per_tensor_symmetric if act_symmetric else torch.per_tensor_affine
77 dtype=torch.int8,
78 quant_min=torch.iinfo(torch.int8).min + 1,
79 quant_max=torch.iinfo(torch.int8).max,
80 qscheme=torch.per_tensor_symmetric,
86 dtype=torch.int32,
87 quant_min=torch.iinfo(torch.int32).min,
88 quant_max=torch.iinfo(torch.int32).max,
89 qscheme=torch.per_tensor_symmetric,
109 dtype=torch.int32,
110 quant_min=torch.iinfo(torch.uint16).min,
111 quant_max=torch.iinfo(torch.uint16).max,
112 qscheme=torch.per_tensor_affine,
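The activation spec at lines 109-112 stores a uint16 range in an int32 container, since torch has no uint16 quantized dtype (see the bypass comments later in the matches). A sketch of that bypass, with the observer choice as an assumption:

import torch
from torch.ao.quantization.observer import MovingAverageMinMaxObserver
from torch.ao.quantization.quantizer import QuantizationSpec

# uint16 activations: int32 carries the values, but quant_min/quant_max
# pin the usable range to [0, 65535] so the encoding stays uint16-compatible
act_quantization_spec = QuantizationSpec(
    dtype=torch.int32,
    quant_min=torch.iinfo(torch.uint16).min,  # 0
    quant_max=torch.iinfo(torch.uint16).max,  # 65535
    qscheme=torch.per_tensor_affine,
    observer_or_fake_quant_ctr=MovingAverageMinMaxObserver.with_args(),
)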
117 dtype=torch.int8,
120 qscheme=torch.per_tensor_symmetric,
126 dtype=torch.int32,
127 quant_min=torch.iinfo(torch.int32).min,
128 quant_max=torch.iinfo(torch.int32).max,
129 qscheme=torch.per_tensor_symmetric,
148 dtype=torch.int32,
149 quant_min=torch.iinfo(torch.uint16).min,
150 quant_max=torch.iinfo(torch.uint16).max,
151 qscheme=torch.per_tensor_affine,
156 dtype=torch.uint8,
157 qscheme=torch.per_tensor_symmetric,
163 dtype=torch.int32,
164 quant_min=torch.iinfo(torch.int32).min,
165 quant_max=torch.iinfo(torch.int32).max,
166 qscheme=torch.per_tensor_symmetric,
185 dtype=torch.int32,
186 quant_min=torch.iinfo(torch.uint16).min,
187 quant_max=torch.iinfo(torch.uint16).max,
188 qscheme=torch.per_tensor_affine,
193 dtype=torch.int16,
194 quant_min=torch.iinfo(torch.int16).min + 1,
195 quant_max=torch.iinfo(torch.int16).max,
196 qscheme=torch.per_tensor_symmetric,
201 # torch does not support uint16 quantization; use int32 to bypass
203 dtype=torch.int32,
204 quant_min=torch.iinfo(torch.int32).min,
205 quant_max=torch.iinfo(torch.int32).max,
206 qscheme=torch.per_tensor_symmetric,
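The int16 weight spec at lines 193-196 trims quant_min to torch.iinfo(torch.int16).min + 1 so the range is symmetric about zero, [-32767, 32767] instead of [-32768, 32767]; the int8 weight specs apply the same +1 trim. A sketch of the pattern, observer choice assumed:

import torch
from torch.ao.quantization.observer import MinMaxObserver
from torch.ao.quantization.quantizer import QuantizationSpec

# symmetric int16 weights: drop -32768 so |quant_min| == quant_max
weight_quantization_spec = QuantizationSpec(
    dtype=torch.int16,
    quant_min=torch.iinfo(torch.int16).min + 1,  # -32767
    quant_max=torch.iinfo(torch.int16).max,  # 32767
    qscheme=torch.per_tensor_symmetric,
    observer_or_fake_quant_ctr=MinMaxObserver.with_args(),
)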
221 act_dtype=torch.uint8,
222 weight_dtype=torch.int8,
228 torch.uint8,
229 torch.uint16,
230 torch.int8,
231 torch.int16,
234 # TODO: accept "int4" temporarily. Remove "int4" once torch supports the torch.int4 dtype
234 supported_weight_dtypes = {"int4", torch.int8, torch.int16}
243 # torch does not support uint16 quantization; use int32 to bypass
245 dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
246 quant_min=torch.iinfo(act_dtype).min,
247 quant_max=torch.iinfo(act_dtype).max,
248 qscheme=torch.per_tensor_affine,
253 dtype=torch.int8 if weight_dtype == "int4" else weight_dtype,
254 quant_min=-7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).min + 1,
255 quant_max=7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).max,
256 qscheme=torch.per_channel_symmetric,
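Since torch has no torch.int4 dtype yet (per the TODO at line 234), "int4" weights are emulated in an int8 container clamped to the symmetric 4-bit range [-7, 7]. A sketch of that branch as the matches at lines 253-256 suggest it, with the observer and ch_axis as assumptions:

import torch
from torch.ao.quantization.observer import PerChannelMinMaxObserver
from torch.ao.quantization.quantizer import QuantizationSpec

weight_dtype = "int4"  # string placeholder until torch ships a real torch.int4

weight_quantization_spec = QuantizationSpec(
    dtype=torch.int8 if weight_dtype == "int4" else weight_dtype,
    # symmetric 4-bit range; -8 is dropped so |quant_min| == quant_max
    quant_min=-7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).min + 1,
    quant_max=7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).max,
    qscheme=torch.per_channel_symmetric,
    ch_axis=0,  # assumption: channel axis 0, typical for conv/linear weights
    observer_or_fake_quant_ctr=PerChannelMinMaxObserver.with_args(),
)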
278 dtype=torch.uint8,
280 torch.per_tensor_symmetric if act_symmetric else torch.per_tensor_affine
286 dtype=torch.uint8,
288 torch.per_tensor_symmetric if act_symmetric else torch.per_tensor_affine
295 dtype=torch.int8,
296 quant_min=torch.iinfo(torch.int8).min + 1,
297 quant_max=torch.iinfo(torch.int8).max,
298 qscheme=torch.per_tensor_symmetric,
303 dtype=torch.int8,
304 quant_min=torch.iinfo(torch.int8).min + 1,
305 quant_max=torch.iinfo(torch.int8).max,
306 qscheme=torch.per_tensor_symmetric,
312 dtype=torch.int32,
313 quant_min=torch.iinfo(torch.int32).min,
314 quant_max=torch.iinfo(torch.int32).max,
315 qscheme=torch.per_tensor_symmetric,
320 dtype=torch.int32,
321 quant_min=torch.iinfo(torch.int32).min,
322 quant_max=torch.iinfo(torch.int32).max,
323 qscheme=torch.per_tensor_symmetric,
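From line 278 on, each spec appears twice: these matches suggest a QAT variant that wraps its observers in FakeQuantize constructors instead of using bare observers (consistent with the fake_quantize import at line 6). A sketch of the distinction, with the concrete observer module as an assumption:

import torch
from torch.ao.quantization.fake_quantize import FakeQuantize
from torch.ao.quantization.observer import MovingAverageMinMaxObserver
from torch.ao.quantization.quantizer import QuantizationSpec

# PTQ: a bare observer only collects ranges during calibration
ptq_act_spec = QuantizationSpec(
    dtype=torch.uint8,
    qscheme=torch.per_tensor_affine,
    observer_or_fake_quant_ctr=MovingAverageMinMaxObserver.with_args(),
)

# QAT: FakeQuantize wraps the same observer and additionally simulates
# round-trip quantization error in the forward pass during training
qat_act_spec = QuantizationSpec(
    dtype=torch.uint8,
    qscheme=torch.per_tensor_affine,
    observer_or_fake_quant_ctr=FakeQuantize.with_args(
        observer=MovingAverageMinMaxObserver
    ),
)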
341 dtype=torch.int32,
342 quant_min=torch.iinfo(torch.uint16).min,
343 quant_max=torch.iinfo(torch.uint16).max,
344 qscheme=torch.per_tensor_affine,
349 dtype=torch.int32,
350 quant_min=torch.iinfo(torch.uint16).min,
351 quant_max=torch.iinfo(torch.uint16).max,
352 qscheme=torch.per_tensor_affine,
357 dtype=torch.int8,
360 qscheme=torch.per_tensor_symmetric,
366 dtype=torch.int8,
369 qscheme=torch.per_tensor_symmetric,
375 dtype=torch.int32,
376 quant_min=torch.iinfo(torch.int32).min,
377 quant_max=torch.iinfo(torch.int32).max,
378 qscheme=torch.per_tensor_symmetric,
383 dtype=torch.int32,
384 quant_min=torch.iinfo(torch.int32).min,
385 quant_max=torch.iinfo(torch.int32).max,
386 qscheme=torch.per_tensor_symmetric,
401 act_dtype=torch.uint8,
402 weight_dtype=torch.int8,
406 torch.uint8,
407 torch.uint16,
408 torch.int8,
409 torch.int16,
411 # TODO: accept "int4" temporarily. Remove "int4" once torch supports the torch.int4 dtype
412 supported_weight_dtypes = {"int4", torch.int8, torch.int16}
421 # torch does not support uint16 quantization; use int32 to bypass
423 dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
424 quant_min=torch.iinfo(act_dtype).min,
425 quant_max=torch.iinfo(act_dtype).max,
426 qscheme=torch.per_tensor_affine,
431 dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
432 quant_min=torch.iinfo(act_dtype).min,
433 quant_max=torch.iinfo(act_dtype).max,
434 qscheme=torch.per_tensor_affine,
439 dtype=torch.int8 if weight_dtype == "int4" else weight_dtype,
440 quant_min=-7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).min + 1,
441 quant_max=7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).max,
442 qscheme=torch.per_channel_symmetric,
447 dtype=torch.int8 if weight_dtype == "int4" else weight_dtype,
448 quant_min=-7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).min + 1,
449 quant_max=7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).max,
450 qscheme=torch.per_channel_symmetric,
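Taken together, the matches at lines 221-256 and 401-450 trace two per-channel config builders, a PTQ one and a QAT one, that take act_dtype and weight_dtype and validate them against the supported sets. A hedged usage sketch; the function name and module path are assumptions inferred from the defaults at lines 221-222 and 401-402, not confirmed by these matches:

import torch

# hypothetical import: the builder name and module path are inferred from the
# matched defaults (act_dtype=torch.uint8, weight_dtype=torch.int8), not confirmed
from executorch.backends.qualcomm.quantizer.qconfig import (
    get_ptq_per_channel_quant_config,
)

# uint16 activations (carried in an int32 container) with emulated-int4 weights
quant_config = get_ptq_per_channel_quant_config(
    act_dtype=torch.uint16,
    weight_dtype="int4",
)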