Skip to content

Commit 9261a8c

Browse files
committed
refactor(go-parser): add unit tests for Go function naming
Signed-off-by: Vladimir Belousov <[email protected]>
1 parent f1f582d commit 9261a8c

File tree

3 files changed

+189
-37
lines changed

3 files changed

+189
-37
lines changed
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import pytest
2+
3+
from vuln_analysis.utils.functions_parsers.golang_functions_parsers import (
4+
GoLanguageFunctionsParser,
5+
)
6+
7+
8+
@pytest.fixture(scope="module")
9+
def go_parser() -> GoLanguageFunctionsParser:
10+
"""
11+
Provides a single instance of the GoLanguageFunctionsParser
12+
for all tests in a module.
13+
"""
14+
return GoLanguageFunctionsParser()
Lines changed: 43 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,65 +1,60 @@
1+
import textwrap
2+
13
from vuln_analysis.utils.go_segmenter_extended import GoSegmenterExtended
24

35

46
def _extract(code: str):
5-
seg = GoSegmenterExtended(code)
6-
return [s.strip() for s in seg.extract_functions_classes()]
7+
seg = GoSegmenterExtended(textwrap.dedent(code))
8+
return seg.extract_functions_classes()
79

810

9-
def test_generic_method_basic():
11+
def test_segmenter_extracts_type_and_generic_method():
1012
code = """
1113
type Box[T any] struct { value T }
1214
func (b *Box[T]) Set(v T) { b.value = v }
1315
"""
14-
chunks = _extract(code)
15-
assert any("Set" in c for c in chunks), "generic method not extracted"
16+
expected_chunks = [
17+
"type Box[T any] struct { value T }",
18+
"func (b *Box[T]) Set(v T) { b.value = v }",
19+
]
1620

17-
18-
def test_generic_multiple_type_params():
19-
code = """
20-
func MapKeys[K comparable, V any](m map[K]V) []K {
21-
keys := make([]K, 0, len(m))
22-
for k := range m {
23-
keys = append(keys, k)
24-
}
25-
return keys
26-
}
27-
"""
28-
chunks = _extract(code)
29-
assert any("MapKeys" in c for c in chunks), "multiple generics not parsed"
21+
actual_chunks = [c.strip() for c in _extract(code)]
22+
assert actual_chunks == expected_chunks
3023

3124

32-
def test_function_returning_func():
25+
def test_segmenter_extracts_toplevel_function_only_and_ignores_nested():
3326
code = """
3427
func makeAdder(x int) func(int) int {
3528
return func(y int) int { return x + y }
3629
}
3730
"""
38-
chunks = _extract(code)
39-
assert any("makeAdder" in c for c in chunks), "failed to parse func returning func"
40-
31+
expected_chunks = [
32+
textwrap.dedent("""
33+
func makeAdder(x int) func(int) int {
34+
return func(y int) int { return x + y }
35+
}
36+
""").strip()
37+
]
4138

42-
def test_inline_anonymous_func():
43-
code = """
44-
func Worker() {
45-
defer func() { cleanup() }()
46-
go func() { runTask() }()
47-
}
48-
"""
49-
chunks = _extract(code)
50-
assert any("Worker" in c for c in chunks), "missed inline anonymous func"
39+
actual_chunks = [c.strip() for c in _extract(code)]
40+
assert actual_chunks == expected_chunks
5141

5242

53-
def test_double_pointer_receiver():
43+
def test_segmenter_handles_double_pointer_receiver():
5444
code = """
5545
type Conn struct{}
5646
func (c **Conn) Reset() {}
5747
"""
58-
chunks = _extract(code)
59-
assert any("Reset" in c for c in chunks), "failed to detect pointer receiver"
48+
expected_chunks = [
49+
"type Conn struct{}",
50+
"func (c **Conn) Reset() {}",
51+
]
6052

53+
actual_chunks = [c.strip() for c in _extract(code)]
54+
assert actual_chunks == expected_chunks
6155

62-
def test_multiline_generic_method():
56+
57+
def test_segmenter_handles_multiline_generic_method():
6358
code = """
6459
func (r *Repo[
6560
T any,
@@ -68,5 +63,16 @@ def test_multiline_generic_method():
6863
return nil, nil
6964
}
7065
"""
71-
chunks = _extract(code)
72-
assert any("Save" in c for c in chunks), "multiline generic method not parsed"
66+
expected_chunks = [
67+
textwrap.dedent("""
68+
func (r *Repo[
69+
T any,
70+
E error,
71+
]) Save(v T) (E, error) {
72+
return nil, nil
73+
}
74+
""").strip()
75+
]
76+
77+
actual_chunks = [c.strip() for c in _extract(code)]
78+
assert actual_chunks == expected_chunks
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
2+
import textwrap
3+
4+
import pytest
5+
from langchain_core.documents import Document
6+
7+
HAPPY_PATH_CASES = [
8+
("simple_function", "func DoSomething() {}", "DoSomething"),
9+
("with_parameters", "func DoSomething(p1 string, p2 int) {}", "DoSomething"),
10+
("with_return_value", "func DoSomething(v int) string {}", "DoSomething"),
11+
(
12+
"with_named_return",
13+
"func DoSomething(a, b float64) (q float64, e error) {}",
14+
"DoSomething",
15+
),
16+
(
17+
"method_with_receiver",
18+
"func (p *Point) DoSomething() float64 {}",
19+
"DoSomething",
20+
),
21+
]
22+
23+
24+
EDGE_CASES_TEST = [
25+
("generic_function", "func DoSomething[T any](s []T) {}", "DoSomething"),
26+
(
27+
"letter_or_underscores",
28+
"func _internal_calculate_v2() {}",
29+
"_internal_calculate_v2",
30+
),
31+
(
32+
"receivers_double_pointer_function",
33+
"func (c **Connection) Close() error {}",
34+
"Close",
35+
),
36+
(
37+
"receivers_without_the_name_function",
38+
"func (*Point) IsOrigin() bool {}",
39+
"IsOrigin",
40+
),
41+
(
42+
"multiline_function",
43+
"""
44+
func (r *Repository[
45+
T Model,
46+
K KeyType,
47+
]) FindByID(id K) (*T, error) {}
48+
""",
49+
"FindByID",
50+
),
51+
]
52+
53+
NEGATIVE_ANONYMOUS_CASES = [
54+
(
55+
"assigned_to_variable",
56+
"var greeter = func(name string) { fmt.Println('Hello,', name) }",
57+
),
58+
(
59+
"assigned_to_variable2",
60+
textwrap.dedent(
61+
"""
62+
greet := func() { // Assigning anonymous function to a variable 'greet'
63+
fmt.Println("Greetings from a variable-assigned anonymous function!")
64+
}
65+
"""
66+
),
67+
),
68+
(
69+
"go_routine",
70+
"go func() { fmt.Println('Running in background') }()",
71+
),
72+
(
73+
"defer_statement",
74+
"defer func() { file.Close() }()",
75+
),
76+
(
77+
"callback_argument",
78+
"http.HandleFunc('/', func(w http.ResponseWriter, r *http.Request) {})",
79+
),
80+
]
81+
82+
MALFORMED_INPUT_CASES = [
83+
("empty_string", ""),
84+
("whitespace_only", " \n\t "),
85+
("just_the_keyword", "func"),
86+
("incomplete_header", "func myFunc("),
87+
("garbage_input", "a = b + c;"),
88+
]
89+
90+
@pytest.mark.parametrize("test_id, code_snippet, expected_name", HAPPY_PATH_CASES)
91+
def test_happy_path_function_names(go_parser, test_id,code_snippet, expected_name):
92+
doc = Document(page_content=code_snippet.strip(), metadata={"source": "test.go"})
93+
actual_name = go_parser.get_function_name(doc)
94+
assert actual_name == expected_name, f"Test case '{test_id}' failed"
95+
96+
@pytest.mark.parametrize("test_id, code_snippet, expected_name", EDGE_CASES_TEST)
97+
def test_edge_cases_function_names(go_parser, test_id, code_snippet, expected_name):
98+
doc = Document(page_content=code_snippet.strip(), metadata={"source": "test.go"})
99+
actual_name = go_parser.get_function_name(doc)
100+
assert actual_name == expected_name, f"Test case '{test_id}' failed"
101+
102+
@pytest.mark.parametrize("test_id, code_snippet", NEGATIVE_ANONYMOUS_CASES)
103+
def test_negative_cases_anonymous_functions(go_parser, test_id, code_snippet):
104+
doc = Document(page_content=code_snippet.strip(), metadata={"source": "proxy.go"})
105+
name = go_parser.get_function_name(doc)
106+
assert name.startswith("anon_"), (
107+
f"[{test_id}] Expected name to start with 'anon_', but got '{name}'"
108+
)
109+
parts = name.split("_")
110+
assert len(parts) == 3, (
111+
f"[{test_id}] Expected name format 'anon_<prefix>_<hash>', but got '{name}'"
112+
)
113+
114+
assert parts[1] == "proxy", (
115+
f"[{test_id}] Expected file prefix 'proxy', but got '{parts[1]}'"
116+
)
117+
hash_part = parts[2]
118+
assert len(hash_part) == 8, (
119+
f"[{test_id}] Hash part should be 8 characters, but got '{hash_part}'"
120+
)
121+
assert all(c in "0123456789abcdef" for c in hash_part), (
122+
f"[{test_id}] Hash part should be hex, but got '{hash_part}'"
123+
)
124+
125+
@pytest.mark.parametrize("test_id, code_snippet", MALFORMED_INPUT_CASES)
126+
def test_malformed_input_graceful_failure(go_parser, test_id, code_snippet):
127+
doc = Document(page_content=code_snippet, metadata={"source": "malformed.go"})
128+
name = go_parser.get_function_name(doc)
129+
130+
assert name.startswith("anon_"), (
131+
f"[{test_id}] Failed to handle malformed input gracefully. Got: {name}"
132+
)

0 commit comments

Comments
 (0)