#
# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
# or more contributor license agreements. Licensed under the Elastic License 2.0;
# you may not use this file except in compliance with the Elastic License 2.0.
#
import json
import os
from functools import cached_property
from typing import Any, Awaitable, Callable

import aiohttp
from aiohttp import ClientResponseError
from connectors_sdk.logger import logger
from notion_client import APIResponseError, AsyncClient

from connectors.utils import CancellableSleeps, RetryStrategy, retryable

RETRIES = 3
RETRY_INTERVAL = 2
DEFAULT_RETRY_SECONDS = 30
BASE_URL = "https://api.notion.com"
MAX_CONCURRENT_CLIENT_SUPPORT = 30


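# NOTE: allowing the Notion base URL to be overridden via the OVERRIDE_URL
# environment variable looks intended for pointing the client at a mock server
# (assumption; the consumers of this override are not shown in this file).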
if "OVERRIDE_URL" in os.environ:
    BASE_URL = os.environ["OVERRIDE_URL"]


class NotFound(Exception):
    pass


class NotionClient:
    """Notion API client"""

    def __init__(self, configuration):
        self._sleeps = CancellableSleeps()
        self.configuration = configuration
        self._logger = logger
        self.notion_secret_key = self.configuration["notion_secret_key"]

    def set_logger(self, logger_):
        self._logger = logger_

    @cached_property
    def _get_client(self):
        return AsyncClient(
            auth=self.notion_secret_key,
            base_url=BASE_URL,
        )

    @cached_property
    def session(self):
        """Generate aiohttp client session.

        Returns:
            aiohttp.ClientSession: An instance of Client Session
        """
        connector = aiohttp.TCPConnector(limit=MAX_CONCURRENT_CLIENT_SUPPORT)

        return aiohttp.ClientSession(
            connector=connector,
            raise_for_status=True,
        )

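    # The @retryable decorator below (from connectors.utils) is assumed to re-invoke
    # the wrapped call up to RETRIES times with exponential backoff, and to skip
    # retries for NotFound, as its arguments suggest.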
    @retryable(
        retries=RETRIES,
        interval=RETRY_INTERVAL,
        strategy=RetryStrategy.EXPONENTIAL_BACKOFF,
        skipped_exceptions=NotFound,
    )
    async def get_via_session(self, url):
        """Yield the response of a GET request to the given URL, handling 429 and 404."""
        self._logger.debug(f"Fetching data from url {url}")
        try:
            async with self.session.get(url=url) as response:
                yield response
        except ClientResponseError as e:
            if e.status == 429:
                # Retry-After arrives as a header string; convert before sleeping.
                retry_seconds = int(
                    e.headers.get("Retry-After") or DEFAULT_RETRY_SECONDS
                )
                self._logger.debug(
                    f"Rate Limit reached: retry in {retry_seconds} seconds"
                )
                await self._sleeps.sleep(retry_seconds)
                raise
            elif e.status == 404:
                raise NotFound from e
            else:
                raise

    @retryable(
        retries=RETRIES,
        interval=RETRY_INTERVAL,
        strategy=RetryStrategy.EXPONENTIAL_BACKOFF,
        skipped_exceptions=NotFound,
    )
    async def fetch_results(
        self, function: Callable[..., Awaitable[Any]], next_cursor=None, **kwargs: Any
    ):
        try:
            return await function(start_cursor=next_cursor, **kwargs)
        except APIResponseError as exception:
            if exception.code == "rate_limited" or exception.status == 429:
                retry_after = (
                    exception.headers.get("retry-after") or DEFAULT_RETRY_SECONDS
                )
                request_info = f"Request: {function.__name__} (next_cursor: {next_cursor}, kwargs: {kwargs})"
                self._logger.info(
                    f"Connector will attempt to retry after {int(retry_after)} seconds. {request_info}"
                )
                await self._sleeps.sleep(int(retry_after))
                msg = "Rate limit exceeded."
                raise Exception(msg) from exception
            else:
                raise

    async def async_iterate_paginated_api(
        self, function: Callable[..., Awaitable[Any]], **kwargs: Any
    ):
        """Return an async iterator over the results of any paginated Notion API."""
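        # Illustrative use (mirrors fetch_users below): iterate every user without
        # handling cursors by hand:
        #     async for user in self.async_iterate_paginated_api(self._get_client.users.list):
        #         ...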
        next_cursor = kwargs.pop("start_cursor", None)
        while True:
            response = await self.fetch_results(function, next_cursor, **kwargs)
            if response:
                for result in response.get("results"):
                    yield result
                next_cursor = response.get("next_cursor")
                if not response["has_more"] or next_cursor is None:
                    return
            else:
                # Stop instead of looping forever if no response comes back.
                return

    async def fetch_owner(self):
        """Fetch integration authorized owner"""
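        # The response body is discarded; presumably the call is only used to verify
        # that the integration token is valid (assumption: callers are not shown here).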
        await self._get_client.users.me()

    async def close(self):
        self._sleeps.cancel()
        await self._get_client.aclose()
        await self.session.close()
        del self._get_client
        del self.session

    async def fetch_users(self):
        """Iterate over user information retrieved from the API.

        Yields:
            dict: User document information excluding bots.
        """
        async for user_document in self.async_iterate_paginated_api(
            self._get_client.users.list
        ):
            if user_document.get("type") != "bot":
                yield user_document

    async def fetch_child_blocks(self, block_id):
        """Fetch child blocks recursively for a given block ID.

        Args:
            block_id (str): The ID of the parent block.

        Yields:
            dict: Child block information.
        """

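        # Depth-first traversal: yield each direct child, then recurse into any
        # children that child has in turn.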
        async def fetch_children_recursively(block):
            if block.get("has_children") is True:
                async for child_block in self.async_iterate_paginated_api(
                    self._get_client.blocks.children.list, block_id=block.get("id")
                ):
                    yield child_block

                    async for grandchild in fetch_children_recursively(child_block):  # pyright: ignore
                        yield grandchild

        try:
            async for block in self.async_iterate_paginated_api(
                self._get_client.blocks.children.list, block_id=block_id
            ):
                if block.get("type") not in [
                    "child_database",
                    "child_page",
                    "unsupported",
                ]:
                    yield block
                    if block.get("has_children") is True:
                        async for child in fetch_children_recursively(block):
                            yield child
                if block.get("type") == "child_database":
                    async for record in self.query_database(block.get("id")):
                        yield record
        except APIResponseError as error:
            if error.code == "validation_error" and "external_object" in json.loads(
                error.body
            ).get("message"):
                self._logger.warning(
                    f"Encountered external object with id: {block_id}. Skipping: {error}"
                )
            elif error.code == "object_not_found":
                self._logger.warning(f"Object not found: {error}")
            else:
                raise

    async def fetch_by_query(self, query):
        async for document in self.async_iterate_paginated_api(
            self._get_client.search, **query
        ):
            yield document
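            # When the search is filtered to databases, each hit is itself a
            # database, so also yield every row it contains.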
            if query and query.get("filter", {}).get("value") == "database":
                async for database in self.query_database(document.get("id")):
                    yield database

    async def fetch_comments(self, block_id):
        async for block_comment in self.async_iterate_paginated_api(
            self._get_client.comments.list, block_id=block_id
        ):
            yield block_comment

    async def query_database(self, database_id, body=None):
        if body is None:
            body = {}
        async for result in self.async_iterate_paginated_api(
            self._get_client.databases.query, database_id=database_id, **body
        ):
            yield result
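

# A minimal usage sketch (illustration only, not part of the connector). It assumes a
# configuration mapping that provides "notion_secret_key", as read in __init__ above:
#
#     import asyncio
#
#     async def main():
#         client = NotionClient(configuration={"notion_secret_key": "<secret>"})
#         try:
#             async for user in client.fetch_users():
#                 print(user.get("name"))
#         finally:
#             await client.close()
#
#     asyncio.run(main())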