Skip to content

Commit 9a8e62c

Browse files
committed
Added Tool HuggingFace API
1 parent 6d1f661 commit 9a8e62c

File tree

3 files changed

+154
-0
lines changed

3 files changed

+154
-0
lines changed
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
from __future__ import annotations

from typing import Any

from langchain.tools import BaseTool
3+
4+
class HuggingFaceHubTool(BaseTool):
    """Base tool for searching the Hugging Face Hub.

    Subclasses choose what is searched by setting ``task`` to ``"models"`` or
    ``"datasets"``; any other value makes ``_run`` return an error message.
    """

    # Hub client. Built in ``__init__`` when the caller does not supply one.
    # Typed as ``Any`` because ``huggingface_hub`` is an optional dependency
    # that is only imported at construction time, so the real class cannot be
    # resolved as a field annotation.
    api_client: Any = None
    # Which listing to query: "models" or "datasets".
    task: str = ""
    # Maximum number of results requested from the Hub per query.
    top_k_results: int = 3

    def __init__(self, **kwargs: Any) -> None:
        """Create the tool, building a default ``HfApi`` client if needed.

        Args:
            **kwargs: Field values forwarded to the base tool. Pass
                ``api_client`` to inject a pre-configured client.

        Raises:
            ImportError: If no client was supplied and ``huggingface_hub``
                is not installed.
        """
        # The client is placed in kwargs *before* base-class initialisation so
        # that it goes through normal field handling instead of being assigned
        # onto an already-validated instance.
        if kwargs.get("api_client") is None:
            try:
                from huggingface_hub import HfApi
            except ImportError as e:
                raise ImportError(
                    "huggingface_hub is not installed. "
                    "Please install it with `pip install huggingface-hub`"
                ) from e
            kwargs["api_client"] = HfApi()
        super().__init__(**kwargs)

    def _run(self, query: str) -> str:
        """Search the Hub for ``query`` and return a formatted summary.

        Args:
            query: Free-text search string.

        Returns:
            One ``ID`` / ``Author`` / ``Tags`` paragraph per hit, separated by
            ``---`` lines; or a message saying nothing was found, the task is
            invalid, or the Hub request failed. Errors are reported in the
            returned text rather than raised.
        """
        if self.task == "models":
            list_entries = self.api_client.list_models
        elif self.task == "datasets":
            list_entries = self.api_client.list_datasets
        else:
            return "Invalid task specified for the Hugging Face Hub tool."

        try:
            # The listing can be a lazy iterable, whose truthiness says nothing
            # about whether it is empty, so materialise it before checking.
            # Consuming it is also where the request happens, hence inside
            # the ``try``.
            entries = list(list_entries(search=query, limit=self.top_k_results))
        except Exception as e:
            # Deliberate best-effort: the caller gets text, not a traceback.
            return f"An error occurred: {e}"

        if not entries:
            return f"No {self.task} found on the Hugging Face Hub for '{query}'."

        return "\n\n---\n\n".join(self._format_entry(entry) for entry in entries)

    @staticmethod
    def _format_entry(entry: Any) -> str:
        """Render one Hub entry as three ``ID`` / ``Author`` / ``Tags`` lines."""
        entry_id = getattr(entry, "modelId", None) or getattr(entry, "id", None)
        author = getattr(entry, "author", None)
        # ``tags`` may be absent or None; joining None would raise.
        tags = getattr(entry, "tags", None) or []
        return (
            f"ID: {entry_id or 'N/A'}\n"
            f"Author: {author or 'N/A'}\n"
            f"Tags: {', '.join(tags)}"
        )

    async def _arun(self, query: str) -> str:
        """Run ``_run`` in the default executor so the event loop is not blocked."""
        import asyncio

        return await asyncio.get_running_loop().run_in_executor(
            None, self._run, query
        )
56+
57+
58+
class HuggingFaceModelSearchTool(HuggingFaceHubTool):
    """Hub search tool preconfigured for the ``"models"`` task."""

    # Identifier under which the tool is exposed.
    name: str = "hugging_face_model_search"
    # Usage text shown alongside the tool.
    description: str = (
        "Use this tool to search for models on the Hugging Face Hub. The input "
        "should be a search query string. The output will be a formatted string "
        "with the top results, including their ID, author, and tags."
    )
    # Selects the model listing in the base class's ``_run``.
    task: str = "models"
69+
70+
71+
class HuggingFaceDatasetSearchTool(HuggingFaceHubTool):
    """Hub search tool preconfigured for the ``"datasets"`` task."""

    # Identifier under which the tool is exposed.
    name: str = "hugging_face_dataset_search"
    # Usage text shown alongside the tool.
    description: str = (
        "Use this tool to search for datasets on the Hugging Face Hub. The input "
        "should be a search query string. The output will be a formatted string "
        "with the top results, including their ID, author, and tags."
    )
    # Selects the dataset listing in the base class's ``_run``.
    task: str = "datasets"

libs/community/pyproject.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,11 @@ version = "0.3.29"
2626
description = "Community contributed LangChain integrations."
2727
readme = "README.md"
2828

29+
[project.optional-dependencies]
30+
all = [
31+
"huggingface-hub",
32+
]
33+
2934
[project.urls]
3035
"Source Code" = "https://github.com/langchain-ai/langchain-community/tree/main/libs/community"
3136
"Release Notes" = "https://github.com/langchain-ai/langchain/releases?q=tag%3A%22langchain-community%3D%3D0%22&expanded=true"
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
from unittest.mock import MagicMock
2+
import pytest
3+
from langchain_community.tools.huggingface import (
4+
HuggingFaceModelSearchTool,
5+
HuggingFaceDatasetSearchTool,
6+
)
7+
8+
@pytest.fixture
def mock_hf_api(mocker):
    """Replace ``huggingface_hub.HfApi`` with a mock client and return it.

    The mock's ``list_models`` and ``list_datasets`` return two canned entries
    each; tests may override ``return_value`` to simulate other responses.
    """
    from types import SimpleNamespace

    mock_api = MagicMock()
    # Plain namespaces are used for the entries (not MagicMock) because a
    # MagicMock fabricates any attribute on access, so a dataset entry would
    # appear to have a ``modelId`` and render as a mock repr instead of its id.
    mock_api.list_models.return_value = [
        SimpleNamespace(modelId="gpt2", author="openai", tags=["text-generation"]),
        SimpleNamespace(
            modelId="distilbert-base-uncased",
            author="distilbert",
            tags=["fill-mask"],
        ),
    ]
    mock_api.list_datasets.return_value = [
        SimpleNamespace(id="squad", author="stanford", tags=["question-answering"]),
        SimpleNamespace(id="imdb", author="stanford", tags=["text-classification"]),
    ]
    # The tool imports HfApi from ``huggingface_hub`` at construction time, so
    # the name has to be patched on that package; the tool module itself never
    # holds an ``HfApi`` attribute to patch.
    mocker.patch("huggingface_hub.HfApi", return_value=mock_api)
    return mock_api
26+
27+
28+
def test_huggingface_model_search_tool(mock_hf_api):
    """Model search queries the model listing and formats every hit."""
    tool = HuggingFaceModelSearchTool()
    result = tool.run("test query")

    # Exactly one model listing call, and no dataset call.
    mock_hf_api.list_models.assert_called_once()
    mock_hf_api.list_datasets.assert_not_called()
    call_kwargs = mock_hf_api.list_models.call_args.kwargs
    assert call_kwargs["search"] == "test query"
    # The configured cap must be forwarded. Which keyword carries it is left
    # to the tool, so that this test does not pin a keyword name that a
    # MagicMock client cannot validate.
    assert tool.top_k_results in call_kwargs.values()

    # Both canned entries are rendered, separated exactly once.
    assert "ID: gpt2" in result
    assert "Author: openai" in result
    assert "Tags: text-generation" in result
    assert "ID: distilbert-base-uncased" in result
    assert result.count("\n\n---\n\n") == 1
42+
43+
44+
def test_huggingface_dataset_search_tool(mock_hf_api):
    """Dataset search queries the dataset listing and formats every hit."""
    tool = HuggingFaceDatasetSearchTool()
    result = tool.run("another query")

    # Exactly one dataset listing call, and no model call.
    mock_hf_api.list_datasets.assert_called_once()
    mock_hf_api.list_models.assert_not_called()
    call_kwargs = mock_hf_api.list_datasets.call_args.kwargs
    assert call_kwargs["search"] == "another query"
    # The configured cap must be forwarded. Which keyword carries it is left
    # to the tool, so that this test does not pin a keyword name that a
    # MagicMock client cannot validate.
    assert tool.top_k_results in call_kwargs.values()

    # Both canned entries are rendered, separated exactly once.
    assert "ID: squad" in result
    assert "Author: stanford" in result
    assert "Tags: question-answering" in result
    assert "ID: imdb" in result
    assert result.count("\n\n---\n\n") == 1
58+
59+
60+
def test_huggingface_model_search_no_results(mock_hf_api):
    """An empty model listing yields the 'No models found' message."""
    # Override the fixture's canned entries with an empty listing.
    mock_hf_api.list_models.return_value = []

    expected = "No models found on the Hugging Face Hub for 'empty query'."
    assert HuggingFaceModelSearchTool().run("empty query") == expected

0 commit comments

Comments
 (0)