Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@ services:
- "./tests/system/test_apps/generating_app:/opt/splunk/etc/apps/generating_app"
- "./tests/system/test_apps/reporting_app:/opt/splunk/etc/apps/reporting_app"
- "./tests/system/test_apps/streaming_app:/opt/splunk/etc/apps/streaming_app"
- "./tests/system/test_apps/modularinput_app:/opt/splunk/etc/apps/modularinput_app"
- "./tests/system/test_apps/mcp_enabled_app:/opt/splunk/etc/apps/mcp_enabled_app"
- "./splunklib:/opt/splunk/etc/apps/eventing_app/bin/splunklib"
- "./splunklib:/opt/splunk/etc/apps/generating_app/bin/splunklib"
- "./splunklib:/opt/splunk/etc/apps/reporting_app/bin/splunklib"
- "./splunklib:/opt/splunk/etc/apps/streaming_app/bin/splunklib"
- "./splunklib:/opt/splunk/etc/apps/modularinput_app/bin/splunklib"
- "./splunklib:/opt/splunk/etc/apps/mcp_enabled_app/bin/splunklib"
17 changes: 14 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ name = "splunk-sdk"
dynamic = ["version"]
description = "Splunk Software Development Kit for Python"
readme = "README.md"
requires-python = ">=3.7"
requires-python = ">=3.10"
license = { text = "Apache-2.0" }
authors = [{ name = "Splunk, Inc.", email = "[email protected]" }]
keywords = ["splunk", "sdk"]
Expand All @@ -29,7 +29,12 @@ classifiers = [
"Topic :: Software Development :: Libraries :: Application Frameworks",
]

dependencies = ["python-dotenv>=0.21.1"]
dependencies = [
"fastmcp>=2.12.4",
"httpx>=0.28.1",
"mcp>=1.15.0",
"python-dotenv>=0.21.1",
]
optional-dependencies = { compat = ["six>=1.17.0"] }

[dependency-groups]
Expand All @@ -51,7 +56,12 @@ requires = ["setuptools"]
build-backend = "setuptools.build_meta"

[tool.setuptools]
packages = ["splunklib", "splunklib.modularinput", "splunklib.searchcommands"]
packages = [
"splunklib",
"splunklib.modularinput",
"splunklib.searchcommands",
"splunklib.mcp",
]

[tool.setuptools.dynamic]
version = { attr = "splunklib.__version__" }
Expand All @@ -66,3 +76,4 @@ select = [
"ANN", # flake8 type annotations
"RUF", # ruff-specific rules
]

22 changes: 13 additions & 9 deletions splunklib/binding.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,16 @@
from contextlib import contextmanager
from datetime import datetime
from functools import wraps
from io import BytesIO
from urllib import parse
from http import client
from http.cookies import SimpleCookie
from io import BytesIO
from urllib import parse
from xml.etree.ElementTree import XML, ParseError
from .data import record
from . import __version__

from splunklib.data import Record

from . import __version__
from .data import record

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -511,7 +513,7 @@ class Context:
:param headers: List of extra HTTP headers to send (optional).
:type headers: ``list`` of 2-tuples.
:param retries: Number of retries for each HTTP connection (optional, the default is 0).
NOTE: THIS MAY INCREASE THE NUMBER OF ROUNDTRIP CONNECTIONS
NOTE: THIS MAY INCREASE THE NUMBER OF ROUNDTRIP CONNECTIONS
TO THE SPLUNK SERVER AND BLOCK THE CURRENT THREAD WHILE RETRYING.
:type retries: ``int``
:param retryDelay: How long to wait between connection attempts if `retries` > 0 (optional, defaults to 10s).
Expand Down Expand Up @@ -653,7 +655,9 @@ def connect(self):

@_authentication
@_log_duration
def delete(self, path_segment, owner=None, app=None, sharing=None, **query):
def delete(
self, path_segment, owner=None, app=None, sharing=None, **query
) -> Record:
"""Performs a DELETE operation at the REST path segment with the given
namespace and query.

Expand Down Expand Up @@ -716,7 +720,7 @@ def delete(self, path_segment, owner=None, app=None, sharing=None, **query):
@_log_duration
def get(
self, path_segment, owner=None, app=None, headers=None, sharing=None, **query
):
) -> Record:
"""Performs a GET operation from the REST path segment with the given
namespace and query.

Expand Down Expand Up @@ -783,7 +787,7 @@ def get(
@_log_duration
def post(
self, path_segment, owner=None, app=None, sharing=None, headers=None, **query
):
) -> Record:
"""Performs a POST operation from the REST path segment with the given
namespace and query.

Expand Down Expand Up @@ -1357,7 +1361,7 @@ def get(self, url, headers=None, **kwargs):
url = url + UrlEncoded("?" + _encode(**kwargs), skip_encode=True)
return self.request(url, {"method": "GET", "headers": headers})

def post(self, url, headers=None, **kwargs):
def post(self, url, headers=None, **kwargs) -> Record:
"""Sends a POST request to a URL.

:param url: The URL.
Expand Down
15 changes: 10 additions & 5 deletions splunklib/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@
"""

import contextlib
import datetime
import json
import logging
import re
Expand All @@ -68,8 +67,9 @@
from time import sleep
from urllib import parse

from splunklib.data import Record

from . import data
from .data import record
from .binding import (
AuthenticationError,
Context,
Expand All @@ -80,6 +80,7 @@
_NoAuthenticationToken,
namespace,
)
from .data import record

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -808,7 +809,7 @@ class Endpoint:
:class:`Entity` (essentially HTTP GET and POST methods).
"""

def __init__(self, service, path):
def __init__(self, service: Service, path):
self.service = service
self.path = path

Expand All @@ -833,7 +834,9 @@ def get_api_version(self, path):

return api_version

def get(self, path_segment="", owner=None, app=None, sharing=None, **query):
def get(
self, path_segment="", owner=None, app=None, sharing=None, **query
) -> Record:
"""Performs a GET operation on the path segment relative to this endpoint.

This method is named to match the HTTP method. This method makes at least
Expand Down Expand Up @@ -916,7 +919,9 @@ def get(self, path_segment="", owner=None, app=None, sharing=None, **query):

return self.service.get(path, owner=owner, app=app, sharing=sharing, **query)

def post(self, path_segment="", owner=None, app=None, sharing=None, **query):
def post(
self, path_segment="", owner=None, app=None, sharing=None, **query
) -> Record:
"""Performs a POST operation on the path segment relative to this endpoint.

This method is named to match the HTTP method. This method makes at least
Expand Down
3 changes: 2 additions & 1 deletion splunklib/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
format, which is the format used by most of the REST API.
"""

from typing import Any
from xml.etree.ElementTree import XML

__all__ = ["load", "record"]
Expand Down Expand Up @@ -201,7 +202,7 @@ def load_value(element, nametable=None):


# A generic utility that enables "dot" access to dicts
class Record(dict):
class Record(dict[Any, Any]): # pyright: ignore[reportExplicitAny]
"""This generic utility class enables dot access to members of a Python
dictionary.

Expand Down
Empty file added splunklib/mcp/__init__.py
Empty file.
28 changes: 28 additions & 0 deletions splunklib/mcp/mcp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import httpx
from mcp.types import Tool as MCPTool

from splunklib.mcp.tools.models import AddTool, AddToolsRequest


async def send_mcp_registrations(
endpoint_url: str,
tool_registrations: list[MCPTool],
server_file_path: str,
):
async with httpx.AsyncClient() as client:
add_req = AddToolsRequest(
tools=[
AddTool(script_path=server_file_path, spec=tool)
for tool in tool_registrations
]
)

res = await client.post(endpoint_url, json=add_req.model_dump())
print(res.status_code)
print(res.text)


async def execute_tool(endpoint_url: str):
async with httpx.AsyncClient() as client:
res = await client.post(endpoint_url)
print(res.text)
62 changes: 62 additions & 0 deletions splunklib/mcp/tools/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
from dataclasses import field
from typing import Any, Literal

from fastmcp.client import Client
from fastmcp.server.dependencies import get_context
from fastmcp.tools import Tool as FastMCPTool
from fastmcp.tools.tool import ToolResult
from mcp.types import Tool as MCPTool
from pydantic.main import BaseModel
from typing_extensions import override


class SplunkMeta(BaseModel):
permissions: list[str] = field(default=[])
tool_type: str = field(default="")
schema_version: str = field(default="")
execution_mode: str = field(default="")
execution_endpoint: str = field(default="")


class McpInputOutputSchema(BaseModel):
type: Literal["object"] = "object"
properties: dict[str, Any] = field(default_factory=lambda: {}) # pyright: ignore[reportExplicitAny]
required: list[str] = field(default_factory=lambda: [])


class AddTool(BaseModel):
script_path: str
spec: MCPTool


class AddToolsRequest(BaseModel):
tools: list[AddTool]


class DeleteToolsRequest(BaseModel):
tools: list[str]


class ProxiedTool(FastMCPTool):
script: str

@override
async def run(self, arguments: dict[str, Any]) -> ToolResult: # pyright: ignore[reportExplicitAny]
async def progress_handler(
progress: float, total: float | None, message: str | None
) -> None:
await get_context().report_progress(progress, total, message)

c = Client(transport=self.script)

async with c:
res = await c.call_tool(
self.name, arguments, progress_handler=progress_handler
)

# TODO: we are missing some fields ....
# res.is_error
# res.data
return ToolResult(
content=res.content, structured_content=res.structured_content
)
98 changes: 98 additions & 0 deletions splunklib/mcp/tools/registrations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import configparser
from dataclasses import asdict
from typing import Literal

from fastmcp.client import Client
from mcp.types import Tool as MCPTool

from splunklib.mcp.mcp import send_mcp_registrations
from splunklib.mcp.tools.models import (
McpInputOutputSchema,
SplunkMeta,
)

tool_reg_prefix = "app:mcp_tool"


def filter_sections(section_name: str) -> bool:
return section_name.startswith(tool_reg_prefix)


def match_input_schema(input: Literal["query_string"] | Literal["other"]):
"""Gets super messy :("""
match input:
case "query_string":
return {
"type": "object",
"properties": {
"query_string": {
"type": "string",
"description": "SPL2 query string",
}
},
}
case _:
raise NotImplementedError("We don't know what to put here lol")


def parse_ai_conf(file_path: str) -> list[MCPTool]:
config = configparser.ConfigParser()
all_sections_len = config.read(file_path)
if len(all_sections_len) == 0:
return []

tool_reg_sections: list[str] = list(filter(filter_sections, config.sections()))
if len(tool_reg_sections) == 0:
return []

ini_tools: list[Tool] = []
for reg_section in tool_reg_sections:
reg_section_data = config[reg_section]

name: str = reg_section.split(":")[2]
description = reg_section_data["description"]
# https://modelcontextprotocol.io/specification/2025-06-18/schema#tool
inputSchema = McpInputOutputSchema(properties={}, required=[])
outputSchema = McpInputOutputSchema(properties={}, required=[])
meta = SplunkMeta(
permissions=[
perm.strip()
for perm in reg_section_data["permissions"].strip().split(",")
],
tool_type="search",
schema_version=reg_section_data["schema_version"].strip(),
)

ini_tool = Tool(
name=name,
description=description,
inputSchema=asdict(inputSchema),
outputSchema=asdict(outputSchema),
_meta=asdict(meta),
)
ini_tools.append(ini_tool)

return ini_tools


async def get_mcp_tools(server_path: str) -> list[MCPTool]:
"""Connects to local MCP server to get tools registered with a @tool decorator"""
mcp_client = Client(server_path)

tools: list[MCPTool] = []
async with mcp_client:
tools = await mcp_client.list_tools()

return tools


async def register_tools_to_mcp_server(
server_file_path: str, endpoint_url: str
) -> None:
tool_registrations = await get_mcp_tools(server_file_path)

await send_mcp_registrations(
endpoint_url,
tool_registrations,
server_file_path,
)
Loading
Loading