
Commit 5e29275

cleanup checks for GIL control, GIL=0, and python >= 3.13.3t (#1743)

Authored by Qubitium

* add proper checks for GIL control, GIL=0, and python 3.13.3t
* ruff
* ruff fix
* ruff fix
* rename
* recommend users to upgrade to 3.13.3t

Signed-off-by: Qubitium <[email protected]>

1 parent ffec2bc · commit 5e29275

File tree: 7 files changed, +55 −32 lines

gptqmodel/models/auto.py
Lines changed: 1 addition & 1 deletion

@@ -90,6 +90,7 @@
 from .definitions.gpt_bigcode import GPTBigCodeGPTQ  # noqa: E402
 from .definitions.gpt_neo import GPTNeoGPTQ  # noqa: E402
 from .definitions.gpt_neox import GPTNeoXGPTQ  # noqa: E402
+from .definitions.gpt_oss import GPTOSSGPTQ  # noqa: E402
 from .definitions.gptj import GPTJGPTQ  # noqa: E402
 from .definitions.granite import GraniteGPTQ  # noqa: E402
 from .definitions.grinmoe import GrinMOEGPTQ  # noqa: E402
@@ -130,7 +131,6 @@
 from .definitions.telechat2 import TeleChat2GPTQ
 from .definitions.xverse import XverseGPTQ  # noqa: E402
 from .definitions.yi import YiGPTQ  # noqa: E402
-from .definitions.gpt_oss import GPTOSSGPTQ  # noqa: E402
 
 # make quants and inference more determinisitc
 torch.manual_seed(787)

gptqmodel/models/definitions/gpt_oss.py
Lines changed: 16 additions & 13 deletions

@@ -13,12 +13,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from .._const import EXPERT_INDEX_PLACEHOLDER
-from ..base import BaseGPTQModel
 import torch
 import torch.nn.functional as F
 from torch import nn
 
+from .._const import EXPERT_INDEX_PLACEHOLDER
+from ..base import BaseGPTQModel
+
+
 class GptOssExpertsNew(nn.Module):
     def __init__(self, config, ori_experts=None):
         super().__init__()
@@ -29,7 +31,7 @@ def __init__(self, config, ori_experts=None):
         self.alpha = 1.702
         self.limit = 7.0
         self.quantizing = False
-
+
         self.gate_up = nn.ModuleList([
             nn.Linear(self.hidden_size, 2 * self.expert_dim, dtype=config.dtype)
             for _ in range(self.num_experts)
@@ -90,27 +92,27 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig
             expert_mask = (router_indices == expert_idx).any(dim=-1)  # (num_tokens,)
             if not expert_mask.any():
                 continue
-
+
             expert_tokens = hidden_states[expert_mask]  # (selected_tokens, hidden_size)
-
+
             gate_up_output = self.gate_up[expert_idx](expert_tokens)  # (selected_tokens, 2*expert_dim)
             gate, up = gate_up_output[..., ::2], gate_up_output[..., 1::2]
-
+
             gate = gate.clamp(min=None, max=self.limit)
             up = up.clamp(min=-self.limit, max=self.limit)
             glu = gate * torch.sigmoid(gate * self.alpha)
-
+
             expert_output = self.down[expert_idx]((up + 1) * glu)  # (selected_tokens, hidden_size)
-
+
             expert_weights = routing_weights[expert_mask, expert_idx].unsqueeze(-1)  # (selected_tokens, 1)
-
+
             final_output[expert_mask] += expert_output * expert_weights
-
+
         if seq_len > 1:
             final_output = final_output.view(batch_size, seq_len, self.hidden_size)
         else:
             final_output = final_output.view(batch_size, self.hidden_size)
-
+
         return final_output
 
 class GptOssTopKRouterNew(nn.Module):
@@ -164,10 +166,11 @@ def after_model_load(self, model, load_quantized_model=False):
         return model
 
 import os
-from transformers.integrations.hub_kernels import use_kernel_forward_from_hub
 from concurrent.futures import ThreadPoolExecutor
 from functools import partial
+
 import transformers.models.gpt_oss.modeling_gpt_oss as gpt_oss_modeling
+from transformers.integrations.hub_kernels import use_kernel_forward_from_hub
 
 @use_kernel_forward_from_hub("MegaBlocksMoeMLP")
 class GptOssMLPNew(nn.Module):
@@ -189,7 +192,7 @@ def process_module(name, module, model, config):
        parent, child = name.rsplit(".", maxsplit=1)
        parent = model.get_submodule(parent)
        setattr(parent, child, new_module)
-
+
    with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor:
        process_fn = partial(process_module, model=model, config=model.config)
        list(executor.map(lambda x: process_fn(x[0], x[1]), model.named_modules()))

gptqmodel/nn_modules/qlinear/tritonv2.py
Lines changed: 2 additions & 2 deletions

@@ -23,12 +23,12 @@
 from ...models._const import DEVICE, PLATFORM
 from ...utils.backend import BACKEND
 from ...utils.logger import setup_logger
-from ...utils.python import has_gil
+from ...utils.python import has_gil_disabled
 from .torch import TorchQuantLinear
 
 try:
     # TODO: triton is not compatible with free threading
-    if not has_gil():
+    if not has_gil_disabled():
         raise Exception("GIL is disabled so Triton is not (yet) compatible.")
 
     import triton

gptqmodel/utils/__init__.py
Lines changed: 15 additions & 3 deletions

@@ -16,12 +16,24 @@
 
 from .backend import BACKEND
 from .logger import setup_logger
-from .python import has_gil, log_gil_required
+from .python import gte_python_3_13_3, has_gil_control, has_gil_disabled, log_gil_requirements_for
+
+log = setup_logger()
 
 # TODO: datasets is not compatible with free threading
-if has_gil():
+if has_gil_disabled():
+    log.info("Python GIL is disabled and GPTQModel will auto enable multi-gpu quant acceleration for MoE models plus multi-cpu accelerated packing.")
     from .perplexity import Perplexity
 else:
-    log_gil_required("utils/Perplexity")
+    if has_gil_control():
+        log.warn(
+            "Python >= 3.13T (free-threading) version detected but GIL is not disabled due to manual override or `regex` package compatibility which can be ignored. Please disable GIL via env `PYTHON_GIL=0`.")
+
+    log.warn(
+        "Python GIL is enabled: Multi-gpu quant acceleration for MoE models is sub-optimal and multi-core accelerated cpu packing is also disabled. We recommend Python >= 3.13.3t with Pytorch > 2.8 for mult-gpu quantization and multi-cpu packing with env `PYTHON_GIL=0`.")
+
+    log_gil_requirements_for("utils/Perplexity")
+
+
 
 from .vram import get_vram
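
For reference, a minimal sketch (not part of the commit) of inspecting the interpreter state that this import-time branch keys on; it assumes a CPython 3.13 free-threading (3.13t) build launched with `PYTHON_GIL=0`:

    import sys

    # sys._is_gil_enabled() exists only on free-threading-capable (3.13t) builds.
    if hasattr(sys, "_is_gil_enabled"):
        print("GIL enabled:", sys._is_gil_enabled())  # False when started with PYTHON_GIL=0
    else:
        print("Regular build: the GIL is always on")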

gptqmodel/utils/model.py
Lines changed: 2 additions & 2 deletions

@@ -55,7 +55,7 @@
 from ..nn_modules.qlinear.ipex import HAS_IPEX, IPEXQuantLinear
 from ..quantization import FORMAT, QuantizeConfig
 from ..quantization.config import FORMAT_FIELD_JSON, QUANT_METHOD, dynamic_get
-from . import has_gil
+from . import has_gil_disabled
 from .backend import BACKEND
 from .importer import select_quant_linear
 from .logger import setup_logger
@@ -642,7 +642,7 @@ def pack_model(
     names = list(qModules.keys())
     lock = threading.Lock()
 
-    if not has_gil():
+    if has_gil_disabled():
         from device_smi import Device
         cpu = Device("cpu")
         max_packers = cpu.count * cpu.cores

gptqmodel/utils/python.py
Lines changed: 15 additions & 6 deletions

@@ -1,16 +1,25 @@
+import platform
 import sys
 
 from gptqmodel.utils.logger import setup_logger
+from packaging.version import Version
 
 log = setup_logger()
 
+# Check if GIL (global interpreter lock) is controllable in this Python build.
+# Starting from python 3.13 it is possible to disable GIL
+def has_gil_control():
+    return hasattr(sys, '_is_gil_enabled')
+
 # Check if GIL (global interpreter lock) is enabled.
 # Starting from python 3.13 it is possible to disable GIL
-def has_gil():
-    if hasattr(sys, '_is_gil_enabled'):
-        return sys._is_gil_enabled()
+def has_gil_disabled():
+    return has_gil_control() and not sys._is_gil_enabled()
 
-    return True
+# Check For Python > 3.13.3
+def gte_python_3_13_3():
+    return Version(platform.python_version()) >= Version("3.13.3")
 
-def log_gil_required(feature: str):
-    log.warn.once(f"Feature `{feature}` requires python GIL. Feature is currently skipped/disabled.")
+# torch compile requires GIL=1 or python 3.13.3t with GIL=0
+def log_gil_requirements_for(feature: str):
+    log.warn.once(f"Feature `{feature}` requires python GIL or Python > 3.13.3T (T for Threading-Free edition of Python) plus Torch 2.8. Feature is currently skipped/disabled.")
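
A short usage sketch of the helpers above (illustrative only; assumes gptqmodel is installed so the module imports as shown):

    from gptqmodel.utils.python import (
        gte_python_3_13_3,
        has_gil_control,
        has_gil_disabled,
    )

    # On CPython 3.13.3t started with PYTHON_GIL=0 all three checks are expected to be True;
    # on a standard GIL build, has_gil_control() and has_gil_disabled() both return False.
    print(has_gil_control(), has_gil_disabled(), gte_python_3_13_3())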

gptqmodel/utils/torch.py
Lines changed: 4 additions & 5 deletions

@@ -23,7 +23,7 @@
 from torch.cpu import StreamContext
 
 from ..utils.logger import setup_logger
-from . import has_gil, log_gil_required
+from . import gte_python_3_13_3, has_gil_disabled, log_gil_requirements_for
 
 # pytorch 2.6.0 fixes many compilation errors
 TORCH_HAS_COMPILE = version.parse(torch.__version__).release >= version.Version('2.6').release
@@ -70,10 +70,9 @@ class BalanceStrategy(str, Enum):
     pass
 
 def torch_compile(module: Union[torch.nn.Module, Callable], backend:str ="inductor", mode: str = None, fullgraph=False):
-    # requires torch >2.8 for proper torch.compile
-    # torch compile broken for free threading
-    if not has_gil():
-        log_gil_required("Torch Compile")
+    # requires torch >2.8 for proper torch.compile + Python 3.13.3t (freethreading)
+    if has_gil_disabled() and not gte_python_3_13_3():
+        log_gil_requirements_for("Torch Compile")
         return module
 
 #from ..models.base import PYTORCH_MIN_VERSION_WITH_COMPILE
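
A hedged usage sketch of the gate above: on a free-threaded interpreter older than 3.13.3, `torch_compile` now returns the module uncompiled instead of attempting compilation (the function name and module path come from this diff; the compile path itself falls outside the shown hunk):

    import torch
    from gptqmodel.utils.torch import torch_compile

    layer = torch.nn.Linear(16, 16)
    # Returned unchanged when has_gil_disabled() and not gte_python_3_13_3();
    # otherwise the function proceeds toward torch.compile (not shown in this hunk).
    compiled = torch_compile(layer, backend="inductor")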
