diff --git a/gptqmodel/looper/native_processor.py b/gptqmodel/looper/native_processor.py
index 388ba66cf..01741e709 100644
--- a/gptqmodel/looper/native_processor.py
+++ b/gptqmodel/looper/native_processor.py
@@ -23,7 +23,7 @@
 from ..looper.named_module import NamedModule
 from ..models import BaseGPTQModel
 from ..quantization.config import QuantizeConfig
-from ..quantization.gptq import CPU, DEVICE_1
+from ..quantization.gptq import CPU, DEVICE_1, DEVICE_2
 from ..utils.logger import setup_logger
 
 log = setup_logger()
@@ -75,7 +75,7 @@ def tmp(module, inp: Tuple[torch.Tensor, ...], out: torch.Tensor):
             inp = inp[0].detach()
 
             if self.qcfg.v2_memory_device == "auto":
-                v2_memory_device = DEVICE_1
+                v2_memory_device = DEVICE_2
             elif self.qcfg.v2_memory_device == "cpu":
                 # slower but >= 4x vram memory reduction
                 v2_memory_device = CPU
@@ -84,7 +84,7 @@ def tmp(module, inp: Tuple[torch.Tensor, ...], out: torch.Tensor):
             elif isinstance(self.qcfg.v2_memory_device, torch.device):
                 v2_memory_device = self.qcfg.v2_memory_device
             else:
-                v2_memory_device = DEVICE_1
+                v2_memory_device = DEVICE_2
 
             self.native_inp_caches[name] += [inp.to(device=v2_memory_device)]
             del inp, out
diff --git a/gptqmodel/quantization/gptq.py b/gptqmodel/quantization/gptq.py
index 8cd20c08d..2cf76a79c 100644
--- a/gptqmodel/quantization/gptq.py
+++ b/gptqmodel/quantization/gptq.py
@@ -44,6 +44,7 @@
 DEVICE_0 = auto_select_torch_device(index=0)
 
 # device_1 may be same as device_0 if there is only 1 visible/active device
 DEVICE_1 = auto_select_torch_device(index=1)
+DEVICE_2 = auto_select_torch_device(index=2)
 
 lock = threading.Lock()
diff --git a/gptqmodel/quantization/gptqv2.py b/gptqmodel/quantization/gptqv2.py
index 48790737b..e07e19b7c 100644
--- a/gptqmodel/quantization/gptqv2.py
+++ b/gptqmodel/quantization/gptqv2.py
@@ -29,9 +29,11 @@
 
 from ..looper.named_module import NamedModule
 from ..quantization import QuantizeConfig
+from ..utils.logger import setup_logger
 from ..utils.torch import torch_compile, torch_sync
 from .gptq import DEVICE_1, GPTQ
 
+log = setup_logger()
 
 class GPTQv2(GPTQ):
     def __init__(self, module: NamedModule, qcfg: Optional[QuantizeConfig]=None):
@@ -72,9 +74,46 @@ def __init__(self, module: NamedModule, qcfg: Optional[QuantizeConfig]=None):
         # self.dXXT.addmm_((native_inp.T-reshaped_inp.T), reshaped_inp, beta=beta, alpha=alpha)
         # del native_inp, reshaped_inp
 
+    def find_closest_native_input(self, inp):
+        if not self.native_inps:
+            return None
+
+        # only match candidates with the exact same shape
+        shape_matches = []
+        for i, native_inp in enumerate(self.native_inps):
+            if native_inp.shape == inp.shape:
+                shape_matches.append((i, native_inp))
+
+        # then find the closest tensor value match
+        if shape_matches:
+            closest_idx = -1
+            min_diff = float('inf')
+            for i, native_inp in shape_matches:
+                native_inp = native_inp.to(device=inp.device)
+                diff = (native_inp - inp).abs().sum().item()
+                if diff < min_diff:
+                    min_diff = diff
+                    closest_idx = i
+            if closest_idx != -1:
+                return self.native_inps.pop(closest_idx)
+
+        # no match found
+        return None
+
     def process_batch(self, inp):
         inp = inp.to(device=DEVICE_1, dtype=torch.float32)
-        native_inp = self.native_inps.pop(0).to(device=DEVICE_1, dtype=torch.float32)
+
+        # in-order pop is not compatible with MoE
+        # native_inp = self.native_inps.pop(0).to(device=DEVICE_1, dtype=torch.float32)
+
+        native_inp = self.find_closest_native_input(inp)
+
+        if native_inp is None:
+            log.error(f"Skipping input of shape `{inp.shape}` as it was not matched to native_inputs. If this is a MoE model, this is safe to ignore.")
+            return
+
+        native_inp = native_inp.to(device=inp.device)
+
         if len(inp.shape) == 2:
             inp = inp.unsqueeze(0)
             native_inp = native_inp.unsqueeze(0)
diff --git a/gptqmodel/utils/torch.py b/gptqmodel/utils/torch.py
index a5b51272f..5f076a4eb 100644
--- a/gptqmodel/utils/torch.py
+++ b/gptqmodel/utils/torch.py
@@ -137,12 +137,12 @@ def auto_select_torch_device(index: int = 0):
     if HAS_CUDA:
         # defensive check
         if index > 0 and torch.cuda.device_count() <= index:
-            index = 0
+            index = torch.cuda.device_count() - 1
         device = torch.device(f"cuda:{index}")
     elif HAS_XPU:
         # defensive check
         if index > 0 and torch.xpu.device_count() <= index:
-            index = 0
+            index = torch.xpu.device_count() - 1
         device = torch.device(f"xpu:{index}")
     elif HAS_MPS:
         device = torch.device("mps")  # mps has no index
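
Note on the auto_select_torch_device change: an out-of-range index now clamps to the
last visible accelerator instead of falling back to device 0, so the new DEVICE_2
(used for the v2 native-input cache when v2_memory_device is "auto") resolves to the
highest-indexed GPU on machines with fewer than three devices, keeping cache traffic
off DEVICE_0 where the quantization math runs. A minimal standalone sketch of the
clamping rule, assuming CUDA only (the select_device name is illustrative, not part
of the library):

    import torch

    def select_device(index: int = 0) -> torch.device:
        # Clamp an out-of-range index to the last visible CUDA device,
        # mirroring the defensive check patched above.
        if torch.cuda.is_available():
            if index > 0 and torch.cuda.device_count() <= index:
                index = torch.cuda.device_count() - 1
            return torch.device(f"cuda:{index}")
        return torch.device("cpu")

    # On a 2-GPU box, index=2 is out of range and clamps to cuda:1;
    # previously it would have aliased back onto cuda:0.
    print(select_device(2))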