|
| 1 | + |
| 2 | +import textwrap |
| 3 | + |
| 4 | +import pytest |
| 5 | +from langchain_core.documents import Document |
| 6 | + |
| 7 | +HAPPY_PATH_CASES = [ |
| 8 | + ("simple_function", "func DoSomething() {}", "DoSomething"), |
| 9 | + ("with_parameters", "func DoSomething(p1 string, p2 int) {}", "DoSomething"), |
| 10 | + ("with_return_value", "func DoSomething(v int) string {}", "DoSomething"), |
| 11 | + ( |
| 12 | + "with_named_return", |
| 13 | + "func DoSomething(a, b float64) (q float64, e error) {}", |
| 14 | + "DoSomething", |
| 15 | + ), |
| 16 | + ( |
| 17 | + "method_with_receiver", |
| 18 | + "func (p *Point) DoSomething() float64 {}", |
| 19 | + "DoSomething", |
| 20 | + ), |
| 21 | +] |
| 22 | + |
| 23 | + |
| 24 | +EDGE_CASES_TEST = [ |
| 25 | + ("generic_function", "func DoSomething[T any](s []T) {}", "DoSomething"), |
| 26 | + ( |
| 27 | + "letter_or_underscores", |
| 28 | + "func _internal_calculate_v2() {}", |
| 29 | + "_internal_calculate_v2", |
| 30 | + ), |
| 31 | + ( |
| 32 | + "receivers_double_pointer_function", |
| 33 | + "func (c **Connection) Close() error {}", |
| 34 | + "Close", |
| 35 | + ), |
| 36 | + ( |
| 37 | + "receivers_without_the_name_function", |
| 38 | + "func (*Point) IsOrigin() bool {}", |
| 39 | + "IsOrigin", |
| 40 | + ), |
| 41 | + ( |
| 42 | + "multiline_function", |
| 43 | + """ |
| 44 | + func (r *Repository[ |
| 45 | + T Model, |
| 46 | + K KeyType, |
| 47 | + ]) FindByID(id K) (*T, error) {} |
| 48 | + """, |
| 49 | + "FindByID", |
| 50 | + ), |
| 51 | +] |
| 52 | + |
| 53 | +NEGATIVE_ANONYMOUS_CASES = [ |
| 54 | + ( |
| 55 | + "assigned_to_variable", |
| 56 | + "var greeter = func(name string) { fmt.Println('Hello,', name) }", |
| 57 | + ), |
| 58 | + ( |
| 59 | + "assigned_to_variable2", |
| 60 | + textwrap.dedent( |
| 61 | + """ |
| 62 | + greet := func() { // Assigning anonymous function to a variable 'greet' |
| 63 | + fmt.Println("Greetings from a variable-assigned anonymous function!") |
| 64 | + } |
| 65 | + """ |
| 66 | + ), |
| 67 | + ), |
| 68 | + ( |
| 69 | + "go_routine", |
| 70 | + "go func() { fmt.Println('Running in background') }()", |
| 71 | + ), |
| 72 | + ( |
| 73 | + "defer_statement", |
| 74 | + "defer func() { file.Close() }()", |
| 75 | + ), |
| 76 | + ( |
| 77 | + "callback_argument", |
| 78 | + "http.HandleFunc('/', func(w http.ResponseWriter, r *http.Request) {})", |
| 79 | + ), |
| 80 | +] |
| 81 | + |
| 82 | +MALFORMED_INPUT_CASES = [ |
| 83 | + ("empty_string", ""), |
| 84 | + ("whitespace_only", " \n\t "), |
| 85 | + ("just_the_keyword", "func"), |
| 86 | + ("incomplete_header", "func myFunc("), |
| 87 | + ("garbage_input", "a = b + c;"), |
| 88 | +] |
| 89 | + |
| 90 | +@pytest.mark.parametrize("test_id, code_snippet, expected_name", HAPPY_PATH_CASES) |
| 91 | +def test_happy_path_function_names(go_parser, test_id,code_snippet, expected_name): |
| 92 | + doc = Document(page_content=code_snippet.strip(), metadata={"source": "test.go"}) |
| 93 | + actual_name = go_parser.get_function_name(doc) |
| 94 | + assert actual_name == expected_name, f"Test case '{test_id}' failed" |
| 95 | + |
| 96 | +@pytest.mark.parametrize("test_id, code_snippet, expected_name", EDGE_CASES_TEST) |
| 97 | +def test_edge_cases_function_names(go_parser, test_id, code_snippet, expected_name): |
| 98 | + doc = Document(page_content=code_snippet.strip(), metadata={"source": "test.go"}) |
| 99 | + actual_name = go_parser.get_function_name(doc) |
| 100 | + assert actual_name == expected_name, f"Test case '{test_id}' failed" |
| 101 | + |
| 102 | +@pytest.mark.parametrize("test_id, code_snippet", NEGATIVE_ANONYMOUS_CASES) |
| 103 | +def test_negative_cases_anonymous_functions(go_parser, test_id, code_snippet): |
| 104 | + doc = Document(page_content=code_snippet.strip(), metadata={"source": "proxy.go"}) |
| 105 | + name = go_parser.get_function_name(doc) |
| 106 | + assert name.startswith("anon_"), ( |
| 107 | + f"[{test_id}] Expected name to start with 'anon_', but got '{name}'" |
| 108 | + ) |
| 109 | + parts = name.split("_") |
| 110 | + assert len(parts) == 3, ( |
| 111 | + f"[{test_id}] Expected name format 'anon_<prefix>_<hash>', but got '{name}'" |
| 112 | + ) |
| 113 | + |
| 114 | + assert parts[1] == "proxy", ( |
| 115 | + f"[{test_id}] Expected file prefix 'proxy', but got '{parts[1]}'" |
| 116 | + ) |
| 117 | + hash_part = parts[2] |
| 118 | + assert len(hash_part) == 8, ( |
| 119 | + f"[{test_id}] Hash part should be 8 characters, but got '{hash_part}'" |
| 120 | + ) |
| 121 | + assert all(c in "0123456789abcdef" for c in hash_part), ( |
| 122 | + f"[{test_id}] Hash part should be hex, but got '{hash_part}'" |
| 123 | + ) |
| 124 | + |
| 125 | +@pytest.mark.parametrize("test_id, code_snippet", MALFORMED_INPUT_CASES) |
| 126 | +def test_malformed_input_graceful_failure(go_parser, test_id, code_snippet): |
| 127 | + doc = Document(page_content=code_snippet, metadata={"source": "malformed.go"}) |
| 128 | + name = go_parser.get_function_name(doc) |
| 129 | + |
| 130 | + assert name.startswith("anon_"), ( |
| 131 | + f"[{test_id}] Failed to handle malformed input gracefully. Got: {name}" |
| 132 | + ) |
0 commit comments