Skip to content

Commit ecc17be

Browse files
authored
Add gpu ut (#370)
1 parent e4528e9 commit ecc17be

File tree

5 files changed

+615
-24
lines changed

5 files changed

+615
-24
lines changed

test_cuda/test_2_3bits.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
import copy
2+
import shutil
3+
import sys
4+
import unittest
5+
import re
6+
7+
sys.path.insert(0, "..")
8+
import torch
9+
import transformers
10+
from transformers import AutoModelForCausalLM, AutoTokenizer
11+
12+
from auto_round import AutoRound
13+
from auto_round.eval.evaluation import simple_evaluate
14+
from lm_eval.utils import make_table # pylint: disable=E0401
15+
16+
17+
def get_accuracy(data):
    """Extract the ``acc`` metric from a rendered results table.

    The tests in this file pass the string returned by ``make_table``; this
    function looks for a row fragment of the form ``|acc   |↑  | 0.1234|``.

    Args:
        data: Text of the results table.

    Returns:
        The accuracy as a float, or ``0.0`` when no well-formed ``acc`` cell
        is present.
    """
    # ``\s*`` around the value tolerates a cell with no padding, and the
    # explicit digits[.digits] form rejects tokens such as "1.2.3" that
    # would otherwise make float() raise.
    match = re.search(r'\|acc\s+\|[↑↓]\s+\|\s*(\d+(?:\.\d+)?)\s*\|', data)
    if match is None:
        return 0.0
    return float(match.group(1))
25+
26+
27+
class TestAutoRound(unittest.TestCase):
    """Accuracy checks for 2-bit and 3-bit AutoRound quantization of opt-125m.

    Each test quantizes the model, exports it, evaluates the export on the
    task named in ``cls.tasks`` and asserts a minimum ``acc`` value.
    """

    @classmethod
    def setUpClass(cls):
        # Shared by every test: export directory and evaluation task name.
        cls.save_dir = "./saved"
        cls.tasks = "lambada_openai"

    @classmethod
    def tearDownClass(cls):
        shutil.rmtree("./saved", ignore_errors=True)
        shutil.rmtree("runs", ignore_errors=True)

    def _quantize(self, **autoround_kwargs):
        """Load opt-125m in fp16, run AutoRound with the given options, return the tuner."""
        model_name = "/models/opt-125m"
        model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        autoround = AutoRound(model, tokenizer, **autoround_kwargs)
        autoround.quantize()
        return autoround

    def _export_and_score(self, autoround, export_format):
        """Export in ``export_format``, evaluate the export, and return its accuracy.

        The export directory is removed even when saving or evaluation raises,
        so a failure cannot leave files behind for the next export.
        """
        try:
            autoround.save_quantized(self.save_dir, format=export_format, inplace=False)
            model_args = f"pretrained={self.save_dir}"
            res = simple_evaluate(model="hf", model_args=model_args,
                                  tasks=self.tasks,
                                  batch_size="auto")
            return get_accuracy(make_table(res))
        finally:
            shutil.rmtree(self.save_dir, ignore_errors=True)

    def test_3bits_autogptq(self):
        autoround = self._quantize(bits=3)
        accuracy = self._export_and_score(autoround, "auto_gptq")
        self.assertGreater(accuracy, 0.30)

    def test_norm_bias_tuning(self):
        autoround = self._quantize(bits=2, group_size=64, enable_norm_bias_tuning=True)
        # auto_round format; the original author recorded 0.2212 for this run.
        accuracy = self._export_and_score(autoround, "auto_round")
        self.assertGreater(accuracy, 0.20)

    def test_2bits_autoround(self):
        autoround = self._quantize(bits=2, group_size=64)
        # The same quantized model is exported twice; the original author
        # recorded 0.1985 next to both evaluations.
        for export_format in ("auto_round", "auto_gptq"):
            accuracy = self._export_and_score(autoround, export_format)
            self.assertGreater(accuracy, 0.18, f"format={export_format}")
102+
103+
if __name__ == "__main__":
104+
unittest.main()

test/test_cuda_before_release.py renamed to test_cuda/test_main_func.py

Lines changed: 53 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import transformers
1010
from transformers import AutoModelForCausalLM, AutoTokenizer
1111

12-
from auto_round import AutoRound
12+
from auto_round import AutoRound, AutoRoundAdam
1313
from auto_round.eval.evaluation import simple_evaluate
1414
from lm_eval.utils import make_table # pylint: disable=E0401
1515

@@ -24,7 +24,7 @@ def get_accuracy(data):
2424
return 0.0
2525

2626

27-
class TestAutoRound(unittest.TestCase):
27+
class TestMainFunc(unittest.TestCase):
2828
@classmethod
2929
def setUpClass(self):
3030
self.save_dir = "./saved"
@@ -35,7 +35,6 @@ def tearDownClass(self):
3535
shutil.rmtree("./saved", ignore_errors=True)
3636
shutil.rmtree("runs", ignore_errors=True)
3737

38-
@unittest.skipIf(torch.cuda.is_available() is False, "Skipping because no cuda")
3938
def test_backend(self):
4039
model_name = "/models/opt-125m"
4140
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")
@@ -54,7 +53,7 @@ def test_backend(self):
5453
assert accuracy > 0.35
5554
shutil.rmtree("./saved", ignore_errors=True)
5655

57-
##test auto_round format
56+
##test auto_gptq format
5857
autoround.save_quantized(self.save_dir, format="auto_gptq", inplace=False)
5958
model_args = f"pretrained={self.save_dir}"
6059
res = simple_evaluate(model="hf", model_args=model_args,
@@ -65,7 +64,7 @@ def test_backend(self):
6564
assert accuracy > 0.35
6665
shutil.rmtree("./saved", ignore_errors=True)
6766

68-
##test auto_round format
67+
##test auto_awq format
6968
autoround.save_quantized(self.save_dir, format="auto_awq", inplace=False)
7069
model_args = f"pretrained={self.save_dir}"
7170
res = simple_evaluate(model="hf", model_args=model_args,
@@ -113,27 +112,57 @@ def test_fp_layers(self):
113112

114113
@unittest.skipIf(torch.cuda.is_available() is False, "Skipping because no cuda")
115114
def test_undivided_group_size_tuning(self):
116-
model_name = "/models/falcon-7b"
115+
model_name = "/models/opt-125m"
117116
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")
118117
tokenizer = AutoTokenizer.from_pretrained(model_name)
119118

120-
autoround = AutoRound(model, tokenizer, bits=4, group_size=128, nsamples=1, iters=1)
119+
autoround = AutoRound(model, tokenizer, bits=4, group_size=127, nsamples=2, iters=2)
120+
autoround.quantize()
121+
122+
123+
def test_adam(self):
    """Quantize opt-125m with AutoRoundAdam (4 bits, group size 128) and check accuracy."""
    model_name = "/models/opt-125m"
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    autoround = AutoRoundAdam(model, tokenizer, bits=4, group_size=128)
    autoround.quantize()

    try:
        ##test auto_round format
        autoround.save_quantized(self.save_dir, format="auto_round", inplace=False)
        model_args = f"pretrained={self.save_dir}"
        res = simple_evaluate(model="hf", model_args=model_args,
                              tasks=self.tasks,
                              batch_size="auto")
        accuracy = get_accuracy(make_table(res))
        self.assertGreater(accuracy, 0.35)
    finally:
        # Remove the export even when evaluation or the assertion fails.
        shutil.rmtree("./saved", ignore_errors=True)
140+
141+
def test_autoround_asym(self):
    """Quantize opt-125m asymmetrically (``sym=False``) and check accuracy.

    The test needs the ``autoround_exllamav2_kernels`` extension; when it
    cannot be imported the test is reported as skipped instead of passed.
    """
    try:
        # Imported only to probe availability of the compiled kernels.
        from autoround_exllamav2_kernels import gemm_half_q_half, make_q_matrix  # noqa: F401
    except ImportError:
        self.skipTest("skip autoround asym test, as autoround is not installed from source")

    model_name = "/models/opt-125m"
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    autoround = AutoRound(model, tokenizer, bits=4, group_size=128, sym=False)
    autoround.quantize()

    try:
        ##test auto_round format
        autoround.save_quantized(self.save_dir, format="auto_round", inplace=False)
        model_args = f"pretrained={self.save_dir}"
        res = simple_evaluate(model="hf", model_args=model_args,
                              tasks=self.tasks,
                              batch_size="auto")
        accuracy = get_accuracy(make_table(res))
        self.assertGreater(accuracy, 0.35)
    finally:
        # Remove the export even when evaluation or the assertion fails.
        shutil.rmtree("./saved", ignore_errors=True)
163+
164+
165+
122166

123-
@unittest.skipIf(torch.cuda.is_available() is False, "Skipping because no cuda")
124-
def test_vision_generation(self):
125-
quantized_model_path = "OPEA/Phi-3.5-vision-instruct-qvision-int4-sym-inc"
126-
from auto_round import AutoRoundConfig
127-
device = "auto" ##cpu, hpu, cuda
128-
quantization_config = AutoRoundConfig(
129-
backend=device
130-
)
131-
model = AutoModelForCausalLM.from_pretrained(quantized_model_path, trust_remote_code=True,
132-
device_map=device, quantization_config=quantization_config)
133-
tokenizer = AutoTokenizer.from_pretrained(quantized_model_path)
134-
text = "There is a girl who likes adventure,"
135-
inputs = tokenizer(text, return_tensors="pt").to(model.device)
136-
res = tokenizer.decode(model.generate(**inputs, max_new_tokens=50)[0])
137-
print(res)
138-
assert (
139-
res == """<s> There is a girl who likes adventure, and she is looking for a partner to go on a treasure hunt. She has found a map that leads to a hidden treasure, but she needs a partner to help her decipher the clues and find the treasure. You""")
167+
if __name__ == "__main__":
168+
unittest.main()

test_cuda/test_multiple_card_calib.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import copy
2+
import shutil
3+
import sys
4+
import unittest
5+
import re
6+
7+
sys.path.insert(0, "..")
8+
import torch
9+
import transformers
10+
from transformers import AutoModelForCausalLM, AutoTokenizer
11+
12+
from auto_round import AutoRound
13+
from auto_round.eval.evaluation import simple_evaluate
14+
from lm_eval.utils import make_table # pylint: disable=E0401
15+
import os
16+
17+
def get_accuracy(data):
    """Extract the ``acc`` metric from a rendered results table.

    Looks for a row fragment of the form ``|acc   |↑  | 0.1234|``
    (presumably produced by the ``make_table`` helper this file imports).

    Args:
        data: Text of the results table.

    Returns:
        The accuracy as a float, or ``0.0`` when no well-formed ``acc`` cell
        is present.
    """
    # ``\s*`` around the value tolerates a cell with no padding, and the
    # explicit digits[.digits] form rejects tokens such as "1.2.3" that
    # would otherwise make float() raise.
    match = re.search(r'\|acc\s+\|[↑↓]\s+\|\s*(\d+(?:\.\d+)?)\s*\|', data)
    if match is None:
        return 0.0
    return float(match.group(1))
25+
26+
27+
class TestAutoRound(unittest.TestCase):
    """Smoke test for running the ``auto_round`` command line on two devices."""

    @classmethod
    def setUpClass(cls):
        cls.save_dir = "./saved"
        cls.tasks = "lambada_openai"

    @classmethod
    def tearDownClass(cls):
        shutil.rmtree("./saved", ignore_errors=True)
        shutil.rmtree("runs", ignore_errors=True)

    def test_multiple_card_calib(self):
        """Run ``python -m auto_round`` from the parent directory and require exit code 0."""
        # Local import: this file's import block does not include subprocess.
        import subprocess

        ##test llm script
        # An argument list avoids shell quoting problems (e.g. an interpreter
        # path containing spaces); cwd=".." replaces the former "cd .. &&".
        command = [
            sys.executable, "-m", "auto_round",
            "--model", "/models/Meta-Llama-3.1-8B-Instruct",
            "--devices", "0,1",
            "--quant_lm_head",
            "--disable_eval",
            "--iters", "1",
            "--nsamples", "1",
            "--output_dir", "None",
        ]
        completed = subprocess.run(command, cwd="..", check=False)
        # The exit status used to be stored and ignored, so the test could
        # never fail; a non-zero status now fails it.
        self.assertEqual(completed.returncode, 0, "auto_round CLI exited with a non-zero status")
44+
45+
46+
if __name__ == "__main__":
47+
unittest.main()
48+
49+

0 commit comments

Comments
 (0)