Skip to content

Commit e3a0143

Browse files
wanlin31copybara-github
authored andcommitted
test: add table_test and sample for lat/long field support.
PiperOrigin-RevId: 750384492
1 parent b49ccb0 commit e3a0143

File tree

3 files changed

+101
-38
lines changed

3 files changed

+101
-38
lines changed

google/genai/caches.py

+13-5
Original file line numberDiff line numberDiff line change
@@ -290,10 +290,10 @@ def _LatLng_to_mldev(
290290
) -> dict[str, Any]:
291291
to_object: dict[str, Any] = {}
292292
if getv(from_object, ['latitude']) is not None:
293-
raise ValueError('latitude parameter is not supported in Gemini API.')
293+
setv(to_object, ['latitude'], getv(from_object, ['latitude']))
294294

295295
if getv(from_object, ['longitude']) is not None:
296-
raise ValueError('longitude parameter is not supported in Gemini API.')
296+
setv(to_object, ['longitude'], getv(from_object, ['longitude']))
297297

298298
return to_object
299299

@@ -305,7 +305,11 @@ def _RetrievalConfig_to_mldev(
305305
) -> dict[str, Any]:
306306
to_object: dict[str, Any] = {}
307307
if getv(from_object, ['lat_lng']) is not None:
308-
raise ValueError('lat_lng parameter is not supported in Gemini API.')
308+
setv(
309+
to_object,
310+
['latLng'],
311+
_LatLng_to_mldev(api_client, getv(from_object, ['lat_lng']), to_object),
312+
)
309313

310314
return to_object
311315

@@ -328,8 +332,12 @@ def _ToolConfig_to_mldev(
328332
)
329333

330334
if getv(from_object, ['retrieval_config']) is not None:
331-
raise ValueError(
332-
'retrieval_config parameter is not supported in Gemini API.'
335+
setv(
336+
to_object,
337+
['retrievalConfig'],
338+
_RetrievalConfig_to_mldev(
339+
api_client, getv(from_object, ['retrieval_config']), to_object
340+
),
333341
)
334342

335343
return to_object

google/genai/models.py

+13-5
Original file line numberDiff line numberDiff line change
@@ -323,10 +323,10 @@ def _LatLng_to_mldev(
323323
) -> dict[str, Any]:
324324
to_object: dict[str, Any] = {}
325325
if getv(from_object, ['latitude']) is not None:
326-
raise ValueError('latitude parameter is not supported in Gemini API.')
326+
setv(to_object, ['latitude'], getv(from_object, ['latitude']))
327327

328328
if getv(from_object, ['longitude']) is not None:
329-
raise ValueError('longitude parameter is not supported in Gemini API.')
329+
setv(to_object, ['longitude'], getv(from_object, ['longitude']))
330330

331331
return to_object
332332

@@ -338,7 +338,11 @@ def _RetrievalConfig_to_mldev(
338338
) -> dict[str, Any]:
339339
to_object: dict[str, Any] = {}
340340
if getv(from_object, ['lat_lng']) is not None:
341-
raise ValueError('lat_lng parameter is not supported in Gemini API.')
341+
setv(
342+
to_object,
343+
['latLng'],
344+
_LatLng_to_mldev(api_client, getv(from_object, ['lat_lng']), to_object),
345+
)
342346

343347
return to_object
344348

@@ -361,8 +365,12 @@ def _ToolConfig_to_mldev(
361365
)
362366

363367
if getv(from_object, ['retrieval_config']) is not None:
364-
raise ValueError(
365-
'retrieval_config parameter is not supported in Gemini API.'
368+
setv(
369+
to_object,
370+
['retrievalConfig'],
371+
_RetrievalConfig_to_mldev(
372+
api_client, getv(from_object, ['retrieval_config']), to_object
373+
),
366374
)
367375

368376
return to_object

google/genai/tests/models/test_generate_content_tools.py

+75-28
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import logging
1717
import sys
1818
import typing
19+
1920
import pydantic
2021
import pytest
2122

@@ -120,7 +121,9 @@ def divide_floats(a: float, b: float) -> float:
120121
'tools': [{
121122
'retrieval': {
122123
'vertex_ai_search': {
123-
'datastore': 'projects/vertex-sdk-dev/locations/global/collections/default_collection/dataStores/yvonne_1728691676574'
124+
'datastore': (
125+
'projects/vertex-sdk-dev/locations/global/collections/default_collection/dataStores/yvonne_1728691676574'
126+
)
124127
}
125128
}
126129
}]
@@ -149,7 +152,7 @@ def divide_floats(a: float, b: float) -> float:
149152
exception_if_mldev='retrieval',
150153
exception_if_vertex='400',
151154
),
152-
pytest_helper.TestTableItem(
155+
pytest_helper.TestTableItem(
153156
name='test_vai_search_engine',
154157
parameters=types._GenerateContentParameters(
155158
model='gemini-2.0-flash-001',
@@ -258,6 +261,31 @@ def divide_floats(a: float, b: float) -> float:
258261
config={'tools': [{'code_execution': {}}]},
259262
),
260263
),
264+
pytest_helper.TestTableItem(
265+
name='test_function_google_search_retrieval_with_long_lat',
266+
parameters=types._GenerateContentParameters(
267+
model='gemini-1.5-flash',
268+
contents=t.t_contents(None, 'what is the price of GOOG?'),
269+
config=types.GenerateContentConfig(
270+
tools=[
271+
types.Tool(
272+
google_search_retrieval=types.GoogleSearchRetrieval(
273+
dynamic_retrieval_config=types.DynamicRetrievalConfig(
274+
mode='MODE_UNSPECIFIED'
275+
)
276+
)
277+
),
278+
],
279+
tool_config=types.ToolConfig(
280+
retrieval_config=types.RetrievalConfig(
281+
lat_lng=types.LatLngDict(
282+
latitude=37.7749, longitude=-122.4194
283+
)
284+
)
285+
),
286+
),
287+
),
288+
),
261289
]
262290

263291

@@ -464,7 +492,9 @@ async def test_disable_automatic_function_calling_stream_async(client):
464492

465493

466494
@pytest.mark.asyncio
467-
async def test_automatic_function_calling_no_function_response_stream_async(client):
495+
async def test_automatic_function_calling_no_function_response_stream_async(
496+
client,
497+
):
468498
response = await client.aio.models.generate_content_stream(
469499
model='gemini-1.5-flash',
470500
contents='what is the weather in Boston?',
@@ -658,7 +688,7 @@ def get_weather_pydantic_model(
658688
contents='it is winter now, what is the weather in Boston?',
659689
config={
660690
'tools': [get_weather_pydantic_model],
661-
'automatic_function_calling': {'ignore_call_history': True}
691+
'automatic_function_calling': {'ignore_call_history': True},
662692
},
663693
)
664694

@@ -689,8 +719,9 @@ def get_weather_from_list_of_cities(
689719
response = client.models.generate_content(
690720
model='gemini-1.5-flash',
691721
contents='it is winter now, what is the weather in Boston and New York?',
692-
config={'tools': [get_weather_from_list_of_cities],
693-
'automatic_function_calling': {'ignore_call_history': True},
722+
config={
723+
'tools': [get_weather_from_list_of_cities],
724+
'automatic_function_calling': {'ignore_call_history': True},
694725
},
695726
)
696727

@@ -736,30 +767,34 @@ def get_information(
736767
),
737768
config={
738769
'tools': [get_information],
739-
'automatic_function_calling': {'ignore_call_history': True}
770+
'automatic_function_calling': {'ignore_call_history': True},
740771
},
741772
)
742773
assert 'Sundae' in response.text
743774
assert 'cat' in response.text
744775

745776

746-
def test_automatic_function_calling_with_parameterized_generic_union_type(client):
777+
def test_automatic_function_calling_with_parameterized_generic_union_type(
778+
client,
779+
):
747780
def describe_cities(
748781
country: str,
749782
cities: typing.Optional[list[str]] = None,
750783
) -> str:
751-
"Given a country and an optional list of cities, describe the cities."
784+
'Given a country and an optional list of cities, describe the cities.'
752785
if cities is None:
753786
return 'There are no cities to describe.'
754787
else:
755-
return f'The cities in {country} are: {", ".join(cities)} and they are nice.'
788+
return (
789+
f'The cities in {country} are: {", ".join(cities)} and they are nice.'
790+
)
756791

757792
response = client.models.generate_content(
758793
model='gemini-1.5-flash',
759-
contents=('Can you describe the city of San Francisco?'),
794+
contents='Can you describe the city of San Francisco?',
760795
config={
761796
'tools': [describe_cities],
762-
'automatic_function_calling': {'ignore_call_history': True}
797+
'automatic_function_calling': {'ignore_call_history': True},
763798
},
764799
)
765800
assert 'San Francisco' in response.text
@@ -794,7 +829,7 @@ def test_with_1_empty_tool(client):
794829
contents='What is the price of GOOG?.',
795830
config={
796831
'tools': [{}, get_stock_price],
797-
'automatic_function_calling': {'ignore_call_history': True}
832+
'automatic_function_calling': {'ignore_call_history': True},
798833
},
799834
)
800835

@@ -819,7 +854,9 @@ async def test_vai_search_stream_async(client):
819854
'tools': [{
820855
'retrieval': {
821856
'vertex_ai_search': {
822-
'datastore': 'projects/vertex-sdk-dev/locations/global/collections/default_collection/dataStores/yvonne_1728691676574'
857+
'datastore': (
858+
'projects/vertex-sdk-dev/locations/global/collections/default_collection/dataStores/yvonne_1728691676574'
859+
)
823860
}
824861
}
825862
}]
@@ -835,7 +872,9 @@ async def test_vai_search_stream_async(client):
835872
'tools': [{
836873
'retrieval': {
837874
'vertex_ai_search': {
838-
'datastore': 'projects/vertex-sdk-dev/locations/global/collections/default_collection/dataStores/yvonne_1728691676574'
875+
'datastore': (
876+
'projects/vertex-sdk-dev/locations/global/collections/default_collection/dataStores/yvonne_1728691676574'
877+
)
839878
}
840879
}
841880
}]
@@ -855,7 +894,7 @@ async def divide_integers(a: int, b: int) -> int:
855894
contents='what is the result of 1000/2?',
856895
config={
857896
'tools': [divide_integers],
858-
'automatic_function_calling': {'ignore_call_history': True}
897+
'automatic_function_calling': {'ignore_call_history': True},
859898
},
860899
)
861900

@@ -868,13 +907,13 @@ async def divide_integers(a: int, b: int) -> int:
868907
return a // b
869908

870909
response = await client.aio.models.generate_content(
871-
model='gemini-1.5-flash',
872-
contents='what is the result of 1000/2?',
873-
config={
874-
'tools': [divide_integers],
875-
'automatic_function_calling': {'ignore_call_history': True}
876-
},
877-
)
910+
model='gemini-1.5-flash',
911+
contents='what is the result of 1000/2?',
912+
config={
913+
'tools': [divide_integers],
914+
'automatic_function_calling': {'ignore_call_history': True},
915+
},
916+
)
878917

879918
assert '500' in response.text
880919

@@ -889,7 +928,7 @@ def divide_integers(a: int, b: int) -> int:
889928
contents='what is the result of 1000/2?',
890929
config={
891930
'tools': [divide_integers],
892-
'automatic_function_calling': {'ignore_call_history': True}
931+
'automatic_function_calling': {'ignore_call_history': True},
893932
},
894933
)
895934

@@ -907,7 +946,12 @@ def mystery_function(a: int, b: int) -> int:
907946
config={'tools': [divide_integers]},
908947
)
909948
assert response.automatic_function_calling_history
910-
assert response.automatic_function_calling_history[-1].parts[0].function_response.response['error']
949+
assert (
950+
response.automatic_function_calling_history[-1]
951+
.parts[0]
952+
.function_response.response['error']
953+
)
954+
911955

912956
@pytest.mark.asyncio
913957
async def test_automatic_function_calling_async_float_without_decimal(client):
@@ -971,7 +1015,9 @@ async def get_current_weather_async(city: str) -> str:
9711015

9721016

9731017
@pytest.mark.asyncio
974-
async def test_automatic_function_calling_async_with_async_function_stream(client):
1018+
async def test_automatic_function_calling_async_with_async_function_stream(
1019+
client,
1020+
):
9751021
async def get_current_weather_async(city: str) -> str:
9761022
"""Returns the current weather in the city."""
9771023

@@ -1182,8 +1228,9 @@ def test_code_execution_tool(client):
11821228

11831229
assert response.executable_code
11841230
assert (
1185-
'prime' in response.code_execution_result.lower() or
1186-
'5117' in response.code_execution_result)
1231+
'prime' in response.code_execution_result.lower()
1232+
or '5117' in response.code_execution_result
1233+
)
11871234

11881235

11891236
def test_afc_logs_to_logger_instance(client, caplog):

0 commit comments

Comments
 (0)