Skip to content

Commit 5e006d7

Browse files
wanlin31MarkDaoust
authored andcommitted
test: add table_test and sample for lat/long field support.
PiperOrigin-RevId: 759739814
1 parent c748ad4 commit 5e006d7

File tree

1 file changed

+71
-26
lines changed

1 file changed

+71
-26
lines changed

google/genai/tests/models/test_generate_content_tools.py

Lines changed: 71 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import logging
1717
import sys
1818
import typing
19+
1920
import pydantic
2021
import pytest
2122

@@ -260,6 +261,31 @@ def divide_floats(a: float, b: float) -> float:
260261
config={'tools': [{'code_execution': {}}]},
261262
),
262263
),
264+
pytest_helper.TestTableItem(
265+
name='test_function_google_search_retrieval_with_long_lat',
266+
parameters=types._GenerateContentParameters(
267+
model='gemini-1.5-flash',
268+
contents=t.t_contents(None, 'what is the price of GOOG?'),
269+
config=types.GenerateContentConfig(
270+
tools=[
271+
types.Tool(
272+
google_search_retrieval=types.GoogleSearchRetrieval(
273+
dynamic_retrieval_config=types.DynamicRetrievalConfig(
274+
mode='MODE_UNSPECIFIED'
275+
)
276+
)
277+
),
278+
],
279+
tool_config=types.ToolConfig(
280+
retrieval_config=types.RetrievalConfig(
281+
lat_lng=types.LatLngDict(
282+
latitude=37.7749, longitude=-122.4194
283+
)
284+
)
285+
),
286+
),
287+
),
288+
),
263289
pytest_helper.TestTableItem(
264290
name='test_url_context',
265291
parameters=types._GenerateContentParameters(
@@ -477,7 +503,9 @@ async def test_disable_automatic_function_calling_stream_async(client):
477503

478504

479505
@pytest.mark.asyncio
480-
async def test_automatic_function_calling_no_function_response_stream_async(client):
506+
async def test_automatic_function_calling_no_function_response_stream_async(
507+
client,
508+
):
481509
response = await client.aio.models.generate_content_stream(
482510
model='gemini-1.5-flash',
483511
contents='what is the weather in Boston?',
@@ -671,7 +699,7 @@ def get_weather_pydantic_model(
671699
contents='it is winter now, what is the weather in Boston?',
672700
config={
673701
'tools': [get_weather_pydantic_model],
674-
'automatic_function_calling': {'ignore_call_history': True}
702+
'automatic_function_calling': {'ignore_call_history': True},
675703
},
676704
)
677705

@@ -702,8 +730,9 @@ def get_weather_from_list_of_cities(
702730
response = client.models.generate_content(
703731
model='gemini-1.5-flash',
704732
contents='it is winter now, what is the weather in Boston and New York?',
705-
config={'tools': [get_weather_from_list_of_cities],
706-
'automatic_function_calling': {'ignore_call_history': True},
733+
config={
734+
'tools': [get_weather_from_list_of_cities],
735+
'automatic_function_calling': {'ignore_call_history': True},
707736
},
708737
)
709738

@@ -749,30 +778,34 @@ def get_information(
749778
),
750779
config={
751780
'tools': [get_information],
752-
'automatic_function_calling': {'ignore_call_history': True}
781+
'automatic_function_calling': {'ignore_call_history': True},
753782
},
754783
)
755784
assert 'Sundae' in response.text
756785
assert 'cat' in response.text
757786

758787

759-
def test_automatic_function_calling_with_parameterized_generic_union_type(client):
788+
def test_automatic_function_calling_with_parameterized_generic_union_type(
789+
client,
790+
):
760791
def describe_cities(
761792
country: str,
762793
cities: typing.Optional[list[str]] = None,
763794
) -> str:
764-
"Given a country and an optional list of cities, describe the cities."
795+
'Given a country and an optional list of cities, describe the cities.'
765796
if cities is None:
766797
return 'There are no cities to describe.'
767798
else:
768-
return f'The cities in {country} are: {", ".join(cities)} and they are nice.'
799+
return (
800+
f'The cities in {country} are: {", ".join(cities)} and they are nice.'
801+
)
769802

770803
response = client.models.generate_content(
771804
model='gemini-1.5-flash',
772-
contents=('Can you describe the city of San Francisco?'),
805+
contents='Can you describe the city of San Francisco?',
773806
config={
774807
'tools': [describe_cities],
775-
'automatic_function_calling': {'ignore_call_history': True}
808+
'automatic_function_calling': {'ignore_call_history': True},
776809
},
777810
)
778811
assert 'San Francisco' in response.text
@@ -807,7 +840,7 @@ def test_with_1_empty_tool(client):
807840
contents='What is the price of GOOG?.',
808841
config={
809842
'tools': [{}, get_stock_price],
810-
'automatic_function_calling': {'ignore_call_history': True}
843+
'automatic_function_calling': {'ignore_call_history': True},
811844
},
812845
)
813846

@@ -832,7 +865,9 @@ async def test_vai_search_stream_async(client):
832865
'tools': [{
833866
'retrieval': {
834867
'vertex_ai_search': {
835-
'datastore': 'projects/vertex-sdk-dev/locations/global/collections/default_collection/dataStores/yvonne_1728691676574'
868+
'datastore': (
869+
'projects/vertex-sdk-dev/locations/global/collections/default_collection/dataStores/yvonne_1728691676574'
870+
)
836871
}
837872
}
838873
}]
@@ -848,7 +883,9 @@ async def test_vai_search_stream_async(client):
848883
'tools': [{
849884
'retrieval': {
850885
'vertex_ai_search': {
851-
'datastore': 'projects/vertex-sdk-dev/locations/global/collections/default_collection/dataStores/yvonne_1728691676574'
886+
'datastore': (
887+
'projects/vertex-sdk-dev/locations/global/collections/default_collection/dataStores/yvonne_1728691676574'
888+
)
852889
}
853890
}
854891
}]
@@ -868,7 +905,7 @@ async def divide_integers(a: int, b: int) -> int:
868905
contents='what is the result of 1000/2?',
869906
config={
870907
'tools': [divide_integers],
871-
'automatic_function_calling': {'ignore_call_history': True}
908+
'automatic_function_calling': {'ignore_call_history': True},
872909
},
873910
)
874911

@@ -881,13 +918,13 @@ async def divide_integers(a: int, b: int) -> int:
881918
return a // b
882919

883920
response = await client.aio.models.generate_content(
884-
model='gemini-1.5-flash',
885-
contents='what is the result of 1000/2?',
886-
config={
887-
'tools': [divide_integers],
888-
'automatic_function_calling': {'ignore_call_history': True}
889-
},
890-
)
921+
model='gemini-1.5-flash',
922+
contents='what is the result of 1000/2?',
923+
config={
924+
'tools': [divide_integers],
925+
'automatic_function_calling': {'ignore_call_history': True},
926+
},
927+
)
891928

892929
assert '500' in response.text
893930

@@ -902,7 +939,7 @@ def divide_integers(a: int, b: int) -> int:
902939
contents='what is the result of 1000/2?',
903940
config={
904941
'tools': [divide_integers],
905-
'automatic_function_calling': {'ignore_call_history': True}
942+
'automatic_function_calling': {'ignore_call_history': True},
906943
},
907944
)
908945

@@ -920,7 +957,12 @@ def mystery_function(a: int, b: int) -> int:
920957
config={'tools': [divide_integers]},
921958
)
922959
assert response.automatic_function_calling_history
923-
assert response.automatic_function_calling_history[-1].parts[0].function_response.response['error']
960+
assert (
961+
response.automatic_function_calling_history[-1]
962+
.parts[0]
963+
.function_response.response['error']
964+
)
965+
924966

925967
@pytest.mark.asyncio
926968
async def test_automatic_function_calling_async_float_without_decimal(client):
@@ -984,7 +1026,9 @@ async def get_current_weather_async(city: str) -> str:
9841026

9851027

9861028
@pytest.mark.asyncio
987-
async def test_automatic_function_calling_async_with_async_function_stream(client):
1029+
async def test_automatic_function_calling_async_with_async_function_stream(
1030+
client,
1031+
):
9881032
async def get_current_weather_async(city: str) -> str:
9891033
"""Returns the current weather in the city."""
9901034

@@ -1195,8 +1239,9 @@ def test_code_execution_tool(client):
11951239

11961240
assert response.executable_code
11971241
assert (
1198-
'prime' in response.code_execution_result.lower() or
1199-
'5117' in response.code_execution_result)
1242+
'prime' in response.code_execution_result.lower()
1243+
or '5117' in response.code_execution_result
1244+
)
12001245

12011246

12021247
def test_afc_logs_to_logger_instance(client, caplog):

0 commit comments

Comments
 (0)