diff --git a/README.md b/README.md
index f6d59e6..79e6989 100644
--- a/README.md
+++ b/README.md
@@ -116,6 +116,7 @@ Below is a comprehensive table of all available tools, how to use them with an a
 | nova_reels | `agent.tool.nova_reels(action="create", text="A cinematic shot of mountains", s3_bucket="my-bucket")` | Create high-quality videos using Amazon Bedrock Nova Reel with configurable parameters via environment variables |
 | agent_core_memory | `agent.tool.agent_core_memory(action="record", content="Hello, I like vegetarian food")` | Store and retrieve memories with Amazon Bedrock Agent Core Memory service |
 | mem0_memory | `agent.tool.mem0_memory(action="store", content="Remember I like to play tennis", user_id="alex")` | Store user and agent memories across agent runs to provide personalized experience |
+| bright_data | `agent.tool.bright_data(action="scrape_as_markdown", url="https://example.com")` | Web scraping, search queries, screenshot capture, and structured data extraction from websites and data feeds |
 | memory | `agent.tool.memory(action="retrieve", query="product features")` | Store, retrieve, list, and manage documents in Amazon Bedrock Knowledge Bases with configurable parameters via environment variables |
 | environment | `agent.tool.environment(action="list", prefix="AWS_")` | Managing environment variables, configuration management |
 | generate_image_stability | `agent.tool.generate_image_stability(prompt="A tranquil pool")` | Creating images using Stability AI models |
@@ -772,12 +773,21 @@ The Mem0 Memory Tool supports three different backend configurations:
 - If `OPENSEARCH_HOST` is set, the tool will use OpenSearch
 - If neither is set, the tool will default to FAISS (requires `faiss-cpu` package)
 - LLM configuration applies to all backend modes and allows customization of the language model used for memory processing
+
+#### Bright Data Tool
+
+| Environment Variable | Description | Default |
+|----------------------|-------------|---------|
+| BRIGHTDATA_API_KEY | Bright Data API key | None |
+| BRIGHTDATA_ZONE | Bright Data Web Unlocker zone | web_unlocker1 |
+
 #### Memory Tool
 
 | Environment Variable | Description | Default |
 |----------------------|-------------|---------|
 | MEMORY_DEFAULT_MAX_RESULTS | Default maximum results for list operations | 50 |
 | MEMORY_DEFAULT_MIN_SCORE | Default minimum relevance score for filtering results | 0.4 |
+
 #### Nova Reels Tool
 
 | Environment Variable | Description | Default |
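A minimal sketch of how these two variables combine with the tool at runtime; the credentials are the placeholder values from the module's own example `.env`, and the zone falls back to `web_unlocker1` when `BRIGHTDATA_ZONE` is unset:

```python
import os

# Placeholder credentials for illustration only.
os.environ["BRIGHTDATA_API_KEY"] = "brd_abc123xyz789"   # required
os.environ["BRIGHTDATA_ZONE"] = "web_unlocker_12345"    # optional override

from strands import Agent
from strands_tools import bright_data

agent = Agent(tools=[bright_data])
agent.tool.bright_data(action="scrape_as_markdown", url="https://example.com")
```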
diff --git a/src/strands_tools/bright_data.py b/src/strands_tools/bright_data.py
new file mode 100644
index 0000000..8460c63
--- /dev/null
+++ b/src/strands_tools/bright_data.py
@@ -0,0 +1,508 @@
+"""
+Tool for web scraping, searching, and data extraction using Bright Data for Strands Agents.
+
+This module provides comprehensive web scraping and data extraction capabilities using
+Bright Data as the backend. It handles all aspects of web scraping with a user-friendly
+interface and proper error handling.
+
+Key Features:
+------------
+1. Supported Actions:
+   • scrape_as_markdown: Scrape webpage content and return it as Markdown
+   • get_screenshot: Take screenshots of webpages
+   • search_engine: Perform search queries using various search engines
+   • web_data_feed: Extract structured data from websites such as LinkedIn, Amazon, and Instagram
+
+2. Advanced Capabilities:
+   • Support for multiple search engines (Google, Bing, Yandex)
+   • Advanced search parameters, including language, location, and device type
+   • Structured data extraction from a wide range of websites
+   • Screenshot generation for web pages
+
+3. Error Handling:
+   • Graceful API error handling
+   • Clear error messages
+   • Timeout management for web_data_feed
+
+Setup Requirements:
+------------------
+1. Create a Bright Data account
+2. Create a Web Unlocker zone in your Bright Data control panel
+3. Set environment variables in your .env file:
+   BRIGHTDATA_API_KEY=your_api_key_here  # Required
+   BRIGHTDATA_ZONE=your_zone_name_here   # Optional, defaults to "web_unlocker1"
+4. DO NOT use Datacenter/Residential proxy zones - they will be blocked
+
+Example .env configuration:
+    BRIGHTDATA_API_KEY=brd_abc123xyz789
+    BRIGHTDATA_ZONE=web_unlocker_12345
+
+Usage Examples:
+--------------
+```python
+from strands import Agent
+from strands_tools import bright_data
+
+agent = Agent(tools=[bright_data])
+
+# Scrape a webpage as Markdown
+agent.tool.bright_data(
+    action="scrape_as_markdown",
+    url="https://example.com"
+)
+
+# Search using Google
+agent.tool.bright_data(
+    action="search_engine",
+    query="climate change solutions",
+    engine="google",
+    country_code="us",
+    language="en"
+)
+
+# Extract product data from Amazon
+agent.tool.bright_data(
+    action="web_data_feed",
+    source_type="amazon_product",
+    url="https://www.amazon.com/product-url"
+)
+```
+"""
+
+import json
+import logging
+import os
+import time
+from typing import Dict, Optional
+from urllib.parse import quote
+
+import requests
+from rich.panel import Panel
+from rich.text import Text
+from strands import tool
+
+from strands_tools.utils import console_util
+
+logger = logging.getLogger(__name__)
+
+console = console_util.create()
+
+
+class BrightDataClient:
+    """Client for interacting with the Bright Data API."""
+
+    def __init__(
+        self,
+        api_key: Optional[str] = None,
+        zone: str = "web_unlocker1",
+        verbose: bool = False,
+    ) -> None:
+        """
+        Initialize with an API token and default zone.
+
+        Args:
+            api_key (Optional[str]): Your Bright Data API token; defaults to the BRIGHTDATA_API_KEY env var
+            zone (str): Bright Data zone name
+            verbose (bool): Print additional information about requests
+        """
+        self.api_key = api_key or os.environ.get("BRIGHTDATA_API_KEY")
+        if not self.api_key:
+            raise ValueError(
+                "BRIGHTDATA_API_KEY environment variable is required but not set. "
+                "Please set it to your Bright Data API token or provide it as an argument."
+            )
+
+        self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
+        self.zone = zone
+        self.verbose = verbose
+        self.endpoint = "https://api.brightdata.com/request"
+
+    def make_request(self, payload: Dict) -> str:
+        """
+        Make a request to the Bright Data API.
+
+        Args:
+            payload (Dict): Request payload
+
+        Returns:
+            str: Response text
+        """
+        if self.verbose:
+            print(f"[Bright Data] Request: {payload['url']}")
+
+        response = requests.post(self.endpoint, headers=self.headers, data=json.dumps(payload))
+
+        if response.status_code != 200:
+            raise Exception(f"Failed to scrape: {response.status_code} - {response.text}")
+
+        return response.text
+
+    def scrape_as_markdown(self, url: str, zone: Optional[str] = None) -> str:
+        """
+        Scrape a webpage and return its content in Markdown format.
+
+        Args:
+            url (str): URL to scrape
+            zone (Optional[str]): Override the default Web Unlocker zone name.
+                Must be a Web Unlocker zone - datacenter/residential zones will fail.
+                Defaults to the zone configured on the client ("web_unlocker1" unless overridden).
+
+        Returns:
+            str: Scraped content as Markdown
+        """
+        payload = {"url": url, "zone": zone or self.zone, "format": "raw", "data_format": "markdown"}
+
+        return self.make_request(payload)
+
+    def get_screenshot(self, url: str, output_path: str, zone: Optional[str] = None) -> str:
+        """
+        Take a screenshot of a webpage.
+
+        Args:
+            url (str): URL to screenshot
+            output_path (str): Path to save the screenshot
+            zone (Optional[str]): Override the default zone
+
+        Returns:
+            str: Path to the saved screenshot
+        """
+        payload = {"url": url, "zone": zone or self.zone, "format": "raw", "data_format": "screenshot"}
+
+        response = requests.post(self.endpoint, headers=self.headers, data=json.dumps(payload))
+
+        if response.status_code != 200:
+            raise Exception(f"Error {response.status_code}: {response.text}")
+
+        with open(output_path, "wb") as f:
+            f.write(response.content)
+
+        return output_path
+
+    @staticmethod
+    def encode_query(query: str) -> str:
+        """URL-encode a search query."""
+        return quote(query)
+
+    def search_engine(
+        self,
+        query: str,
+        engine: str = "google",
+        zone: Optional[str] = None,
+        language: Optional[str] = None,
+        country_code: Optional[str] = None,
+        search_type: Optional[str] = None,
+        start: Optional[int] = None,
+        num_results: Optional[int] = 10,
+        location: Optional[str] = None,
+        device: Optional[str] = None,
+        return_json: bool = False,
+    ) -> str:
+        """
+        Search using Google, Bing, or Yandex with advanced parameters and return results in Markdown.
+
+        Args:
+            query (str): Search query
+            engine (str): Search engine - 'google', 'bing', or 'yandex'
+            zone (Optional[str]): Override the default Web Unlocker zone name.
+                Must be a Web Unlocker zone - datacenter/residential zones will fail.
+                Defaults to the zone configured on the client ("web_unlocker1" unless overridden).
+
+            # Google SERP-specific parameters
+            language (Optional[str]): Two-letter language code (hl parameter)
+            country_code (Optional[str]): Two-letter country code (gl parameter)
+            search_type (Optional[str]): Type of search (images, shopping, news, etc.)
+            start (Optional[int]): Results pagination offset (0=first page, 10=second page)
+            num_results (Optional[int]): Number of results to return (default 10)
+            location (Optional[str]): Location for search results (uule parameter)
+            device (Optional[str]): Device type (mobile, ios, android, ipad, android_tablet)
+            return_json (bool): Return parsed JSON instead of HTML/Markdown
+
+        Returns:
+            str: Search results as Markdown or JSON
+        """
+        encoded_query = self.encode_query(query)
+
+        base_urls = {
+            "google": f"https://www.google.com/search?q={encoded_query}",
+            "bing": f"https://www.bing.com/search?q={encoded_query}",
+            "yandex": f"https://yandex.com/search/?text={encoded_query}",
+        }
+
+        if engine not in base_urls:
+            raise ValueError(f"Unsupported search engine: {engine}. Use 'google', 'bing', or 'yandex'")
+
+        search_url = base_urls[engine]
+
+        if engine == "google":
+            params = []
+
+            if language:
+                params.append(f"hl={language}")
+
+            if country_code:
+                params.append(f"gl={country_code}")
+
+            if search_type:
+                if search_type == "jobs":
+                    params.append("ibp=htl;jobs")
+                else:
+                    search_types = {"images": "isch", "shopping": "shop", "news": "nws"}
+                    tbm_value = search_types.get(search_type, search_type)
+                    params.append(f"tbm={tbm_value}")
+
+            if start is not None:
+                params.append(f"start={start}")
+
+            if num_results:
+                params.append(f"num={num_results}")
+
+            if location:
+                params.append(f"uule={self.encode_query(location)}")
+
+            if device:
+                device_value = "1"  # plain mobile user agent
+
+                if device in ["ios", "iphone"]:
+                    device_value = "ios"
+                elif device == "ipad":
+                    device_value = "ios_tablet"
+                elif device == "android":
+                    device_value = "android"
+                elif device == "android_tablet":
+                    device_value = "android_tablet"
+
+                params.append(f"brd_mobile={device_value}")
+
+            if return_json:
+                params.append("brd_json=1")
+
+            if params:
+                search_url += "&" + "&".join(params)
+
+        payload = {
+            "url": search_url,
+            "zone": zone or self.zone,
+            "format": "raw",
+            "data_format": "markdown" if not return_json else "raw",
+        }
+
+        return self.make_request(payload)
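For illustration, here is the query URL the Google branch above assembles for a representative call. The parameter values are placeholders; only the mappings (`hl`, `gl`, `tbm`, `num`) come from the code:

```python
from urllib.parse import quote

# Mirrors search_engine(query="climate change solutions", language="en",
#                       country_code="us", search_type="news", num_results=10)
q = quote("climate change solutions")  # -> climate%20change%20solutions
url = f"https://www.google.com/search?q={q}" + "&" + "&".join(
    ["hl=en", "gl=us", "tbm=nws", "num=10"]
)
# https://www.google.com/search?q=climate%20change%20solutions&hl=en&gl=us&tbm=nws&num=10
```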
+    def web_data_feed(
+        self,
+        source_type: str,
+        url: str,
+        num_of_reviews: Optional[int] = None,
+        timeout: int = 600,
+        polling_interval: int = 1,
+    ) -> Dict:
+        """
+        Retrieve structured web data from sources such as LinkedIn, Amazon, and Instagram.
+
+        Args:
+            source_type (str): Type of data source (e.g., 'linkedin_person_profile', 'amazon_product')
+            url (str): URL of the web resource to retrieve data from
+            num_of_reviews (Optional[int]): Number of reviews to retrieve (only for facebook_company_reviews)
+            timeout (int): Polling-attempt budget; equals seconds to wait at the default 1-second interval
+            polling_interval (int): Time in seconds between polling attempts
+
+        Returns:
+            Dict: Structured data from the requested source
+        """
+        datasets = {
+            "amazon_product": "gd_l7q7dkf244hwjntr0",
+            "amazon_product_reviews": "gd_le8e811kzy4ggddlq",
+            "linkedin_person_profile": "gd_l1viktl72bvl7bjuj0",
+            "linkedin_company_profile": "gd_l1vikfnt1wgvvqz95w",
+            "zoominfo_company_profile": "gd_m0ci4a4ivx3j5l6nx",
+            "instagram_profiles": "gd_l1vikfch901nx3by4",
+            "instagram_posts": "gd_lk5ns7kz21pck8jpis",
+            "instagram_reels": "gd_lyclm20il4r5helnj",
+            "instagram_comments": "gd_ltppn085pokosxh13",
+            "facebook_posts": "gd_lyclm1571iy3mv57zw",
+            "facebook_marketplace_listings": "gd_lvt9iwuh6fbcwmx1a",
+            "facebook_company_reviews": "gd_m0dtqpiu1mbcyc2g86",
+            "x_posts": "gd_lwxkxvnf1cynvib9co",
+            "zillow_properties_listing": "gd_lfqkr8wm13ixtbd8f5",
+            "booking_hotel_listings": "gd_m5mbdl081229ln6t4a",
+            "youtube_videos": "gd_m5mbdl081229ln6t4a",
+        }
+
+        if source_type not in datasets:
+            valid_sources = ", ".join(datasets.keys())
+            raise ValueError(f"Invalid source_type: {source_type}. Valid options are: {valid_sources}")
+
+        dataset_id = datasets[source_type]
+
+        request_data = {"url": url}
+        if source_type == "facebook_company_reviews" and num_of_reviews is not None:
+            request_data["num_of_reviews"] = str(num_of_reviews)
+
+        trigger_response = requests.post(
+            "https://api.brightdata.com/datasets/v3/trigger",
+            params={"dataset_id": dataset_id, "include_errors": True},
+            headers=self.headers,
+            json=[request_data],
+        )
+
+        trigger_data = trigger_response.json()
+        if not trigger_data.get("snapshot_id"):
+            raise Exception("No snapshot ID returned from trigger request")
+
+        snapshot_id = trigger_data["snapshot_id"]
+        if self.verbose:
+            print(f"[Bright Data] {source_type} triggered with snapshot ID: {snapshot_id}")
+
+        attempts = 0
+        # One request per attempt; with the default 1-second polling interval the attempt
+        # count approximates elapsed seconds, so `timeout` serves as the attempt budget.
+        max_attempts = timeout
+
+        while attempts < max_attempts:
+            try:
+                snapshot_response = requests.get(
+                    f"https://api.brightdata.com/datasets/v3/snapshot/{snapshot_id}",
+                    params={"format": "json"},
+                    headers=self.headers,
+                )
+
+                snapshot_data = snapshot_response.json()
+
+                if isinstance(snapshot_data, dict) and snapshot_data.get("status") == "running":
+                    if self.verbose:
+                        print(
+                            f"[Bright Data] Snapshot not ready, polling again (attempt {attempts + 1}/{max_attempts})"
+                        )
+                    attempts += 1
+                    time.sleep(polling_interval)
+                    continue
+
+                if self.verbose:
+                    print(f"[Bright Data] Data received after {attempts + 1} attempts")
+
+                return snapshot_data
+
+            except Exception as e:
+                if self.verbose:
+                    print(f"[Bright Data] Polling error: {e!s}")
+                attempts += 1
+                time.sleep(polling_interval)
+
+        raise TimeoutError(f"Timeout after {max_attempts} seconds waiting for {source_type} data")
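The client can also be used directly, outside the agent wrapper. A minimal sketch, assuming a valid `BRIGHTDATA_API_KEY` in the environment and placeholder URLs:

```python
from strands_tools.bright_data import BrightDataClient

client = BrightDataClient(verbose=True)  # reads BRIGHTDATA_API_KEY, zone "web_unlocker1"

page_md = client.scrape_as_markdown("https://example.com")
shot_path = client.get_screenshot("https://example.com", "/tmp/example.png")

# web_data_feed triggers a dataset snapshot, then polls until the snapshot stops
# reporting {"status": "running"} or the attempt budget is exhausted.
product = client.web_data_feed(
    source_type="amazon_product",
    url="https://www.amazon.com/product-url",
    timeout=120,          # attempt budget (seconds at the default 1s interval)
    polling_interval=2,
)
```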
+
+
+@tool
+def bright_data(
+    action: str,
+    url: Optional[str] = None,
+    output_path: Optional[str] = None,
+    zone: Optional[str] = None,
+    query: Optional[str] = None,
+    engine: str = "google",
+    language: Optional[str] = None,
+    country_code: Optional[str] = None,
+    search_type: Optional[str] = None,
+    start: Optional[int] = None,
+    num_results: int = 10,
+    location: Optional[str] = None,
+    device: Optional[str] = None,
+    return_json: bool = False,
+    source_type: Optional[str] = None,
+    num_of_reviews: Optional[int] = None,
+    timeout: int = 600,
+    polling_interval: int = 1,
+) -> str:
+    """
+    Web scraping and data extraction tool powered by Bright Data.
+
+    This tool provides a comprehensive interface for web scraping and data extraction using
+    Bright Data, including scraping web pages as Markdown, taking screenshots, performing
+    search queries, and extracting structured data from various websites.
+
+    Args:
+        action: The action to perform (scrape_as_markdown, get_screenshot, search_engine, web_data_feed)
+        url: URL to scrape or extract data from (for scrape_as_markdown, get_screenshot, web_data_feed)
+        output_path: Path to save the screenshot (for get_screenshot)
+        zone: Web Unlocker zone name (optional). If not provided, uses the BRIGHTDATA_ZONE environment
+            variable, or defaults to "web_unlocker1". Set BRIGHTDATA_ZONE in your .env file to
+            configure your specific Web Unlocker zone name (e.g., BRIGHTDATA_ZONE=web_unlocker_12345)
+        query: Search query (for search_engine)
+        engine: Search engine to use (google, bing, or yandex; default: google)
+        language: Two-letter language code for search results (hl parameter for Google)
+        country_code: Two-letter country code for search results (gl parameter for Google)
+        search_type: Type of search (images, shopping, news, etc.)
+        start: Results pagination offset (0=first page, 10=second page)
+        num_results: Number of results to return (default: 10)
+        location: Location for search results (uule parameter)
+        device: Device type (mobile, ios, android, ipad, android_tablet)
+        return_json: Return parsed JSON instead of HTML/Markdown (default: False)
+        source_type: Type of data source for web_data_feed (e.g., 'linkedin_person_profile', 'amazon_product')
+        num_of_reviews: Number of reviews to retrieve (only for facebook_company_reviews)
+        timeout: Polling-attempt budget for data retrieval; seconds at the default 1s interval (default: 600)
+        polling_interval: Time in seconds between polling attempts (default: 1)
+
+    Returns:
+        str: Response content from the requested operation
+    """
+    try:
+        if not action:
+            raise ValueError("action parameter is required")
+
+        # Zone resolution order: explicit argument > BRIGHTDATA_ZONE env var > "web_unlocker1".
+        if zone is None:
+            zone = os.environ.get("BRIGHTDATA_ZONE", "web_unlocker1")
+
+        client = BrightDataClient(verbose=True, zone=zone)
+
+        if action == "scrape_as_markdown":
+            if not url:
+                raise ValueError("url is required for scrape_as_markdown action")
+            return client.scrape_as_markdown(url, zone)
+
+        elif action == "get_screenshot":
+            if not url:
+                raise ValueError("url is required for get_screenshot action")
+            if not output_path:
+                raise ValueError("output_path is required for get_screenshot action")
+            output_path_result = client.get_screenshot(url, output_path, zone)
+            return f"Screenshot saved to {output_path_result}"
+
+        elif action == "search_engine":
+            if not query:
+                raise ValueError("query is required for search_engine action")
+            return client.search_engine(
+                query=query,
+                engine=engine,
+                zone=zone,
+                language=language,
+                country_code=country_code,
+                search_type=search_type,
+                start=start,
+                num_results=num_results,
+                location=location,
+                device=device,
+                return_json=return_json,
+            )
+
+        elif action == "web_data_feed":
+            if not url:
+                raise ValueError("url is required for web_data_feed action")
+            if not source_type:
+                raise ValueError("source_type is required for web_data_feed action")
+            data = client.web_data_feed(
+                source_type=source_type,
+                url=url,
+                num_of_reviews=num_of_reviews,
+                timeout=timeout,
+                polling_interval=polling_interval,
+            )
+            return json.dumps(data, indent=2)
+
+        else:
+            raise ValueError(f"Invalid action: {action}")
+
+    except Exception as e:
+        error_panel = Panel(
+            Text(str(e), style="red"),
+            title="Bright Data Operation Error",
+            border_style="red",
+        )
+        console.print(error_panel)
+        raise
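Before the tests, a sketch of calling the decorated function directly, which is the same pattern the test suite below uses. URLs are placeholders; return types follow the dispatch above:

```python
from strands_tools.bright_data import bright_data

md = bright_data(action="scrape_as_markdown", url="https://example.com")  # Markdown str
msg = bright_data(
    action="get_screenshot", url="https://example.com", output_path="/tmp/page.png"
)  # "Screenshot saved to /tmp/page.png"
feed_json = bright_data(
    action="web_data_feed",
    source_type="amazon_product",
    url="https://www.amazon.com/product-url",
)  # JSON string (json.dumps of the snapshot data)
```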
+""" + +import json +import os +from unittest.mock import MagicMock, patch + +import pytest +from strands import Agent +from strands_tools import bright_data +from strands_tools.bright_data import BrightDataClient + + +@pytest.fixture +def agent(): + """Create an agent with the bright_data tool loaded.""" + return Agent(tools=[bright_data]) + + +@pytest.fixture +def mock_bright_data_client(): + """Create a mock Bright Data client.""" + client = MagicMock(spec=BrightDataClient) + return client + + +@patch.dict(os.environ, {"BRIGHTDATA_API_KEY": "test_api_key"}) +@patch("strands_tools.bright_data.BrightDataClient") +def test_scrape_as_markdown(mock_bright_data_client_class, mock_bright_data_client): + """Test scrape_as_markdown functionality.""" + mock_bright_data_client_class.return_value = mock_bright_data_client + + markdown_content = "# Example Website\n\nThis is example content." + mock_bright_data_client.scrape_as_markdown.return_value = markdown_content + + result = bright_data.bright_data(action="scrape_as_markdown", url="https://example.com", zone="unblocker") + + assert result == markdown_content + mock_bright_data_client.scrape_as_markdown.assert_called_once_with("https://example.com", "unblocker") + + +@patch.dict(os.environ, {"BRIGHTDATA_API_KEY": "test_api_key"}) +@patch("strands_tools.bright_data.BrightDataClient") +def test_get_screenshot(mock_bright_data_client_class, mock_bright_data_client): + """Test get_screenshot functionality.""" + mock_bright_data_client_class.return_value = mock_bright_data_client + + mock_bright_data_client.get_screenshot.return_value = "/tmp/screenshot.png" + + result = bright_data.bright_data( + action="get_screenshot", url="https://example.com", output_path="/tmp/screenshot.png", zone="test_zone" + ) + + assert "Screenshot saved to /tmp/screenshot.png" in result + mock_bright_data_client.get_screenshot.assert_called_once_with( + "https://example.com", "/tmp/screenshot.png", "test_zone" + ) + + +@patch.dict(os.environ, {"BRIGHTDATA_API_KEY": "test_api_key"}) +@patch("strands_tools.bright_data.BrightDataClient") +def test_search_engine(mock_bright_data_client_class, mock_bright_data_client): + """Test search_engine functionality.""" + mock_bright_data_client_class.return_value = mock_bright_data_client + + search_results = "# Search Results\n\n1. Result 1\n2. 
Result 2" + mock_bright_data_client.search_engine.return_value = search_results + + result = bright_data.bright_data( + action="search_engine", + query="test query", + engine="google", + language="en", + country_code="us", + search_type="images", + start=0, + num_results=10, + location="New York", + device="mobile", + return_json=False, + zone="test_zone", + ) + + assert result == search_results + mock_bright_data_client.search_engine.assert_called_once() + call_kwargs = mock_bright_data_client.search_engine.call_args[1] + assert call_kwargs["query"] == "test query" + assert call_kwargs["engine"] == "google" + assert call_kwargs["language"] == "en" + assert call_kwargs["country_code"] == "us" + assert call_kwargs["search_type"] == "images" + assert call_kwargs["start"] == 0 + assert call_kwargs["num_results"] == 10 + assert call_kwargs["location"] == "New York" + assert call_kwargs["device"] == "mobile" + assert call_kwargs["return_json"] is False + assert call_kwargs["zone"] == "test_zone" + + +@patch.dict(os.environ, {"BRIGHTDATA_API_KEY": "test_api_key"}) +@patch("strands_tools.bright_data.BrightDataClient") +def test_web_data_feed(mock_bright_data_client_class, mock_bright_data_client): + """Test web_data_feed functionality.""" + mock_bright_data_client_class.return_value = mock_bright_data_client + + amazon_data = { + "title": "Test Product", + "price": "29.99", + "rating": 4.5, + "reviews_count": 1024, + } + mock_bright_data_client.web_data_feed.return_value = amazon_data + + result = bright_data.bright_data( + action="web_data_feed", + source_type="amazon_product", + url="https://www.amazon.com/product-url", + num_of_reviews=5, + timeout=300, + polling_interval=2, + ) + + assert json.loads(result) == amazon_data + mock_bright_data_client.web_data_feed.assert_called_once_with( + source_type="amazon_product", + url="https://www.amazon.com/product-url", + num_of_reviews=5, + timeout=300, + polling_interval=2, + ) + + +@patch.dict(os.environ, {"BRIGHTDATA_API_KEY": "test_api_key"}) +def test_missing_required_parameters(): + """Test missing required parameters for different actions.""" + + # Test missing URL for scrape_as_markdown + with pytest.raises(ValueError) as exc_info: + bright_data.bright_data(action="scrape_as_markdown") + assert "url is required for scrape_as_markdown action" in str(exc_info.value) + + # Test missing URL for get_screenshot + with pytest.raises(ValueError) as exc_info: + bright_data.bright_data(action="get_screenshot") + assert "url is required for get_screenshot action" in str(exc_info.value) + + # Test missing output_path for get_screenshot + with pytest.raises(ValueError) as exc_info: + bright_data.bright_data(action="get_screenshot", url="https://example.com") + assert "output_path is required for get_screenshot action" in str(exc_info.value) + + # Test missing query for search_engine + with pytest.raises(ValueError) as exc_info: + bright_data.bright_data(action="search_engine") + assert "query is required for search_engine action" in str(exc_info.value) + + # Test missing source_type for web_data_feed + with pytest.raises(ValueError) as exc_info: + bright_data.bright_data(action="web_data_feed", url="https://example.com") + assert "source_type is required for web_data_feed action" in str(exc_info.value) + + # Test missing URL for web_data_feed + with pytest.raises(ValueError) as exc_info: + bright_data.bright_data(action="web_data_feed", source_type="amazon_product") + assert "url is required for web_data_feed action" in str(exc_info.value) + + 
+@patch.dict(os.environ, {"BRIGHTDATA_API_KEY": "test_api_key"}) +def test_invalid_action(): + """Test invalid action.""" + with pytest.raises(ValueError) as exc_info: + bright_data.bright_data(action="invalid") + assert "Invalid action: invalid" in str(exc_info.value) + + +@patch.dict(os.environ, {}) +def test_missing_api_key(): + """Test missing Bright Data API key.""" + with pytest.raises(ValueError) as exc_info: + bright_data.bright_data(action="scrape_as_markdown", url="https://example.com") + assert "BRIGHTDATA_API_KEY environment variable is required" in str(exc_info.value) + + +def test_missing_action(): + """Test missing action parameter.""" + with patch.dict(os.environ, {"BRIGHTDATA_API_KEY": "test_api_key"}): + with pytest.raises(ValueError) as exc_info: + bright_data.bright_data(action="") + assert "action parameter is required" in str(exc_info.value) + + +@patch.dict(os.environ, {"BRIGHTDATA_API_KEY": "test_api_key"}) +def test_bright_data_client_methods(): + """Test BrightDataClient class methods directly.""" + client = BrightDataClient(api_key="test_api_key", zone="test_zone", verbose=True) + + assert client.api_key == "test_api_key" + assert client.zone == "test_zone" + assert client.verbose is True + + with patch("requests.post") as mock_post: + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.text = "Test content" + mock_post.return_value = mock_response + + payload = {"url": "https://example.com", "zone": "test_zone"} + result = client.make_request(payload) + assert result == "Test content" + mock_post.assert_called_with( + "https://api.brightdata.com/request", + headers={"Content-Type": "application/json", "Authorization": "Bearer test_api_key"}, + data=json.dumps(payload), + ) + + with patch.object(client, "make_request", return_value="# Markdown Content") as mock_make_request: + result = client.scrape_as_markdown("https://example.com") + assert result == "# Markdown Content" + mock_make_request.assert_called_with( + {"url": "https://example.com", "zone": "test_zone", "format": "raw", "data_format": "markdown"} + ) + + encoded = client.encode_query("test query") + assert encoded == "test%20query" + + +@patch.dict(os.environ, {"BRIGHTDATA_API_KEY": "test_api_key"}) +def test_bright_data_client_failed_request(): + """Test BrightDataClient handling of failed requests.""" + with patch("requests.post") as mock_post: + mock_response = MagicMock() + mock_response.status_code = 403 + mock_response.text = "Forbidden" + mock_post.return_value = mock_response + + client = BrightDataClient() + + payload = {"url": "https://example.com", "zone": "unlocker"} + + with pytest.raises(Exception) as excinfo: + client.make_request(payload) + + assert "Failed to scrape: 403 - Forbidden" in str(excinfo.value) + + +@patch.dict(os.environ, {"BRIGHTDATA_API_KEY": "test_api_key"}) +def test_web_data_feed_timeout(): + """Test web_data_feed timeout handling.""" + with patch("requests.post") as mock_post, patch("requests.get") as mock_get, patch("time.sleep"): + trigger_response = MagicMock() + trigger_response.status_code = 200 + trigger_response.json.return_value = {"snapshot_id": "test_snapshot_id"} + mock_post.return_value = trigger_response + + snapshot_response = MagicMock() + snapshot_response.status_code = 200 + snapshot_response.json.return_value = {"status": "running"} + mock_get.return_value = snapshot_response + + client = BrightDataClient(verbose=True) + + with pytest.raises(TimeoutError) as excinfo: + client.web_data_feed( + source_type="amazon_product", 
url="https://www.amazon.com/product-url", timeout=5, polling_interval=1 + ) + + assert "Timeout after 5 seconds waiting for amazon_product data" in str(excinfo.value) + + +@patch.dict(os.environ, {"BRIGHTDATA_API_KEY": "test_api_key"}) +def test_search_engine_with_defaults(): + """Test search_engine with default parameters.""" + with patch("strands_tools.bright_data.BrightDataClient") as mock_client_class: + mock_client = mock_client_class.return_value + mock_client.search_engine.return_value = "search results" + + result = bright_data.bright_data(action="search_engine", query="test query") + + assert result == "search results" + # Verify default values are used + call_kwargs = mock_client.search_engine.call_args[1] + assert call_kwargs["engine"] == "google" + assert call_kwargs["num_results"] == 10 + assert call_kwargs["return_json"] is False + + +@patch.dict(os.environ, {"BRIGHTDATA_API_KEY": "test_api_key"}) +def test_zone_environment_variable(): + """Test that BRIGHTDATA_ZONE environment variable is used when zone is not provided.""" + with patch.dict(os.environ, {"BRIGHTDATA_ZONE": "custom_zone"}): + with patch("strands_tools.bright_data.BrightDataClient") as mock_client_class: + mock_client = mock_client_class.return_value + mock_client.scrape_as_markdown.return_value = "content" + + bright_data.bright_data(action="scrape_as_markdown", url="https://example.com") + + # Verify the client was created with the custom zone + mock_client_class.assert_called_with(verbose=True, zone="custom_zone") + + +@patch.dict(os.environ, {"BRIGHTDATA_API_KEY": "test_api_key"}) +def test_zone_default_fallback(): + """Test that default zone is used when no zone is provided and no environment variable is set.""" + with patch("strands_tools.bright_data.BrightDataClient") as mock_client_class: + mock_client = mock_client_class.return_value + mock_client.scrape_as_markdown.return_value = "content" + + bright_data.bright_data(action="scrape_as_markdown", url="https://example.com") + + # Verify the client was created with the default zone + mock_client_class.assert_called_with(verbose=True, zone="web_unlocker1")