diff --git a/lldb/cmake/modules/LLDBConfig.cmake b/lldb/cmake/modules/LLDBConfig.cmake index 70e8db40328af..23ccae5e11fa8 100644 --- a/lldb/cmake/modules/LLDBConfig.cmake +++ b/lldb/cmake/modules/LLDBConfig.cmake @@ -77,6 +77,7 @@ add_optional_dependency(LLDB_ENABLE_FBSDVMCORE "Enable libfbsdvmcore support in option(LLDB_USE_ENTITLEMENTS "When codesigning, use entitlements if available" ON) option(LLDB_BUILD_FRAMEWORK "Build LLDB.framework (Darwin only)" OFF) +option(LLDB_ENABLE_PROTOCOL_SERVERS "Enable protocol servers (e.g. MCP) in LLDB" ON) option(LLDB_NO_INSTALL_DEFAULT_RPATH "Disable default RPATH settings in binaries" OFF) option(LLDB_USE_SYSTEM_DEBUGSERVER "Use the system's debugserver for testing (Darwin only)." OFF) option(LLDB_SKIP_STRIP "Whether to skip stripping of binaries when installing lldb." OFF) diff --git a/lldb/include/lldb/Core/Debugger.h b/lldb/include/lldb/Core/Debugger.h index 35a41e419c9bf..eb8aa314cf47b 100644 --- a/lldb/include/lldb/Core/Debugger.h +++ b/lldb/include/lldb/Core/Debugger.h @@ -376,7 +376,7 @@ class Debugger : public std::enable_shared_from_this, bool GetNotifyVoid() const; - const std::string &GetInstanceName() { return m_instance_name; } + const std::string &GetInstanceName() const { return m_instance_name; } bool GetShowInlineDiagnostics() const; diff --git a/lldb/include/lldb/Core/PluginManager.h b/lldb/include/lldb/Core/PluginManager.h index 0c988e5969538..96bf10fa48d38 100644 --- a/lldb/include/lldb/Core/PluginManager.h +++ b/lldb/include/lldb/Core/PluginManager.h @@ -255,6 +255,17 @@ class PluginManager { static void AutoCompleteProcessName(llvm::StringRef partial_name, CompletionRequest &request); + // Protocol + static bool RegisterPlugin(llvm::StringRef name, llvm::StringRef description, + ProtocolServerCreateInstance create_callback); + + static bool UnregisterPlugin(ProtocolServerCreateInstance create_callback); + + static llvm::StringRef GetProtocolServerPluginNameAtIndex(uint32_t idx); + + static ProtocolServerCreateInstance + GetProtocolCreateCallbackForPluginName(llvm::StringRef name); + // Register Type Provider static bool RegisterPlugin(llvm::StringRef name, llvm::StringRef description, RegisterTypeBuilderCreateInstance create_callback); diff --git a/lldb/include/lldb/Core/ProtocolServer.h b/lldb/include/lldb/Core/ProtocolServer.h new file mode 100644 index 0000000000000..937256c10aec1 --- /dev/null +++ b/lldb/include/lldb/Core/ProtocolServer.h @@ -0,0 +1,40 @@ +//===-- ProtocolServer.h --------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLDB_CORE_PROTOCOLSERVER_H +#define LLDB_CORE_PROTOCOLSERVER_H + +#include "lldb/Core/PluginInterface.h" +#include "lldb/Host/Socket.h" +#include "lldb/lldb-private-interfaces.h" + +namespace lldb_private { + +class ProtocolServer : public PluginInterface { +public: + ProtocolServer() = default; + virtual ~ProtocolServer() = default; + + static ProtocolServer *GetOrCreate(llvm::StringRef name); + + static std::vector GetSupportedProtocols(); + + struct Connection { + Socket::SocketProtocol protocol; + std::string name; + }; + + virtual llvm::Error Start(Connection connection) = 0; + virtual llvm::Error Stop() = 0; + + virtual Socket *GetSocket() const = 0; +}; + +} // namespace lldb_private + +#endif diff --git a/lldb/include/lldb/Host/JSONTransport.h b/lldb/include/lldb/Host/JSONTransport.h new file mode 100644 index 0000000000000..4087cdf2b42f7 --- /dev/null +++ b/lldb/include/lldb/Host/JSONTransport.h @@ -0,0 +1,141 @@ +//===-- JSONTransport.h ---------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Transport layer for encoding and decoding JSON protocol messages. +// +//===----------------------------------------------------------------------===// + +#ifndef LLDB_HOST_JSONTRANSPORT_H +#define LLDB_HOST_JSONTRANSPORT_H + +#include "lldb/lldb-forward.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Error.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/JSON.h" +#include +#include + +namespace lldb_private { + +class TransportEOFError : public llvm::ErrorInfo { +public: + static char ID; + + TransportEOFError() = default; + + void log(llvm::raw_ostream &OS) const override { + OS << "transport end of file reached"; + } + std::error_code convertToErrorCode() const override { + return llvm::inconvertibleErrorCode(); + } +}; + +class TransportTimeoutError : public llvm::ErrorInfo { +public: + static char ID; + + TransportTimeoutError() = default; + + void log(llvm::raw_ostream &OS) const override { + OS << "transport operation timed out"; + } + std::error_code convertToErrorCode() const override { + return std::make_error_code(std::errc::timed_out); + } +}; + +class TransportInvalidError : public llvm::ErrorInfo { +public: + static char ID; + + TransportInvalidError() = default; + + void log(llvm::raw_ostream &OS) const override { + OS << "transport IO object invalid"; + } + std::error_code convertToErrorCode() const override { + return std::make_error_code(std::errc::not_connected); + } +}; + +/// A transport class that uses JSON for communication. +class JSONTransport { +public: + JSONTransport(lldb::IOObjectSP input, lldb::IOObjectSP output); + virtual ~JSONTransport() = default; + + /// Transport is not copyable. + /// @{ + JSONTransport(const JSONTransport &rhs) = delete; + void operator=(const JSONTransport &rhs) = delete; + /// @} + + /// Writes a message to the output stream. + template llvm::Error Write(const T &t) { + const std::string message = llvm::formatv("{0}", toJSON(t)).str(); + return WriteImpl(message); + } + + /// Reads the next message from the input stream. + template + llvm::Expected Read(const std::chrono::microseconds &timeout) { + llvm::Expected message = ReadImpl(timeout); + if (!message) + return message.takeError(); + return llvm::json::parse(/*JSON=*/*message); + } + +protected: + virtual void Log(llvm::StringRef message); + + virtual llvm::Error WriteImpl(const std::string &message) = 0; + virtual llvm::Expected + ReadImpl(const std::chrono::microseconds &timeout) = 0; + + lldb::IOObjectSP m_input; + lldb::IOObjectSP m_output; +}; + +/// A transport class for JSON with a HTTP header. +class HTTPDelimitedJSONTransport : public JSONTransport { +public: + HTTPDelimitedJSONTransport(lldb::IOObjectSP input, lldb::IOObjectSP output) + : JSONTransport(input, output) {} + virtual ~HTTPDelimitedJSONTransport() = default; + +protected: + virtual llvm::Error WriteImpl(const std::string &message) override; + virtual llvm::Expected + ReadImpl(const std::chrono::microseconds &timeout) override; + + // FIXME: Support any header. + static constexpr llvm::StringLiteral kHeaderContentLength = + "Content-Length: "; + static constexpr llvm::StringLiteral kHeaderSeparator = "\r\n\r\n"; +}; + +/// A transport class for JSON RPC. +class JSONRPCTransport : public JSONTransport { +public: + JSONRPCTransport(lldb::IOObjectSP input, lldb::IOObjectSP output) + : JSONTransport(input, output) {} + virtual ~JSONRPCTransport() = default; + +protected: + virtual llvm::Error WriteImpl(const std::string &message) override; + virtual llvm::Expected + ReadImpl(const std::chrono::microseconds &timeout) override; + + static constexpr llvm::StringLiteral kMessageSeparator = "\n"; +}; + +} // namespace lldb_private + +#endif diff --git a/lldb/include/lldb/Host/PipeBase.h b/lldb/include/lldb/Host/PipeBase.h index d51d0cd54e036..ed8df6bf1e511 100644 --- a/lldb/include/lldb/Host/PipeBase.h +++ b/lldb/include/lldb/Host/PipeBase.h @@ -10,12 +10,11 @@ #ifndef LLDB_HOST_PIPEBASE_H #define LLDB_HOST_PIPEBASE_H -#include -#include - #include "lldb/Utility/Status.h" +#include "lldb/Utility/Timeout.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" +#include "llvm/Support/Error.h" namespace lldb_private { class PipeBase { @@ -32,10 +31,9 @@ class PipeBase { virtual Status OpenAsReader(llvm::StringRef name, bool child_process_inherit) = 0; - Status OpenAsWriter(llvm::StringRef name, bool child_process_inherit); - virtual Status - OpenAsWriterWithTimeout(llvm::StringRef name, bool child_process_inherit, - const std::chrono::microseconds &timeout) = 0; + virtual llvm::Error OpenAsWriter(llvm::StringRef name, + bool child_process_inherit, + const Timeout &timeout) = 0; virtual bool CanRead() const = 0; virtual bool CanWrite() const = 0; @@ -56,14 +54,13 @@ class PipeBase { // Delete named pipe. virtual Status Delete(llvm::StringRef name) = 0; - virtual Status WriteWithTimeout(const void *buf, size_t size, - const std::chrono::microseconds &timeout, - size_t &bytes_written) = 0; - Status Write(const void *buf, size_t size, size_t &bytes_written); - virtual Status ReadWithTimeout(void *buf, size_t size, - const std::chrono::microseconds &timeout, - size_t &bytes_read) = 0; - Status Read(void *buf, size_t size, size_t &bytes_read); + virtual llvm::Expected + Write(const void *buf, size_t size, + const Timeout &timeout = std::nullopt) = 0; + + virtual llvm::Expected + Read(void *buf, size_t size, + const Timeout &timeout = std::nullopt) = 0; }; } diff --git a/lldb/include/lldb/Host/Socket.h b/lldb/include/lldb/Host/Socket.h index 304a91bdf6741..59de18424d7c8 100644 --- a/lldb/include/lldb/Host/Socket.h +++ b/lldb/include/lldb/Host/Socket.h @@ -11,7 +11,9 @@ #include #include +#include +#include "lldb/Host/MainLoopBase.h" #include "lldb/lldb-private.h" #include "lldb/Host/SocketAddress.h" @@ -71,6 +73,11 @@ class Socket : public IOObject { ProtocolUnixAbstract }; + enum SocketMode { + ModeAccept, + ModeConnect, + }; + struct HostAndPort { std::string hostname; uint16_t port; @@ -80,6 +87,10 @@ class Socket : public IOObject { } }; + using ProtocolModePair = std::pair; + static std::optional + GetProtocolAndMode(llvm::StringRef scheme); + static const NativeSocket kInvalidSocketValue; ~Socket() override; @@ -97,7 +108,17 @@ class Socket : public IOObject { virtual Status Connect(llvm::StringRef name) = 0; virtual Status Listen(llvm::StringRef name, int backlog) = 0; - virtual Status Accept(Socket *&socket) = 0; + + // Use the provided main loop instance to accept new connections. The callback + // will be called (from MainLoop::Run) for each new connection. This function + // does not block. + virtual llvm::Expected> + Accept(MainLoopBase &loop, + std::function socket)> sock_cb) = 0; + + // Accept a single connection and "return" it in the pointer argument. This + // function blocks until the connection arrives. + virtual Status Accept(Socket *&socket); // Initialize a Tcp Socket object in listening mode. listen and accept are // implemented separately because the caller may wish to manipulate or query @@ -132,6 +153,11 @@ class Socket : public IOObject { // If this Socket is connected then return the URI used to connect. virtual std::string GetRemoteConnectionURI() const { return ""; }; + // If the Socket is listening then return the URI for clients to connect. + virtual std::vector GetListeningConnectionURI() const { + return {}; + } + protected: Socket(SocketProtocol protocol, bool should_close, bool m_child_process_inherit); diff --git a/lldb/include/lldb/Host/common/TCPSocket.h b/lldb/include/lldb/Host/common/TCPSocket.h index 78e80568e3996..a37ae843bed23 100644 --- a/lldb/include/lldb/Host/common/TCPSocket.h +++ b/lldb/include/lldb/Host/common/TCPSocket.h @@ -13,6 +13,8 @@ #include "lldb/Host/Socket.h" #include "lldb/Host/SocketAddress.h" #include +#include +#include namespace lldb_private { class TCPSocket : public Socket { @@ -42,16 +44,10 @@ class TCPSocket : public Socket { Status Connect(llvm::StringRef name) override; Status Listen(llvm::StringRef name, int backlog) override; - // Use the provided main loop instance to accept new connections. The callback - // will be called (from MainLoop::Run) for each new connection. This function - // does not block. + using Socket::Accept; llvm::Expected> Accept(MainLoopBase &loop, - std::function socket)> sock_cb); - - // Accept a single connection and "return" it in the pointer argument. This - // function blocks until the connection arrives. - Status Accept(Socket *&conn_socket) override; + std::function socket)> sock_cb) override; Status CreateSocket(int domain); @@ -59,6 +55,8 @@ class TCPSocket : public Socket { std::string GetRemoteConnectionURI() const override; + std::vector GetListeningConnectionURI() const override; + private: TCPSocket(NativeSocket socket, const TCPSocket &listen_socket); diff --git a/lldb/include/lldb/Host/common/UDPSocket.h b/lldb/include/lldb/Host/common/UDPSocket.h index bae707e345d87..7348010d02ada 100644 --- a/lldb/include/lldb/Host/common/UDPSocket.h +++ b/lldb/include/lldb/Host/common/UDPSocket.h @@ -27,7 +27,13 @@ class UDPSocket : public Socket { size_t Send(const void *buf, const size_t num_bytes) override; Status Connect(llvm::StringRef name) override; Status Listen(llvm::StringRef name, int backlog) override; - Status Accept(Socket *&socket) override; + + llvm::Expected> + Accept(MainLoopBase &loop, + std::function socket)> sock_cb) override { + return llvm::errorCodeToError( + std::make_error_code(std::errc::operation_not_supported)); + } SocketAddress m_sockaddr; }; diff --git a/lldb/include/lldb/Host/posix/DomainSocket.h b/lldb/include/lldb/Host/posix/DomainSocket.h index 35c33811f60de..3a7fb16d3fd75 100644 --- a/lldb/include/lldb/Host/posix/DomainSocket.h +++ b/lldb/include/lldb/Host/posix/DomainSocket.h @@ -10,18 +10,28 @@ #define LLDB_HOST_POSIX_DOMAINSOCKET_H #include "lldb/Host/Socket.h" +#include +#include namespace lldb_private { class DomainSocket : public Socket { public: + DomainSocket(NativeSocket socket, bool should_close, + bool child_processes_inherit); DomainSocket(bool should_close, bool child_processes_inherit); Status Connect(llvm::StringRef name) override; Status Listen(llvm::StringRef name, int backlog) override; - Status Accept(Socket *&socket) override; + + using Socket::Accept; + llvm::Expected> + Accept(MainLoopBase &loop, + std::function socket)> sock_cb) override; std::string GetRemoteConnectionURI() const override; + std::vector GetListeningConnectionURI() const override; + protected: DomainSocket(SocketProtocol protocol, bool child_processes_inherit); diff --git a/lldb/include/lldb/Host/posix/PipePosix.h b/lldb/include/lldb/Host/posix/PipePosix.h index 2e291160817c4..effd33fba7eb0 100644 --- a/lldb/include/lldb/Host/posix/PipePosix.h +++ b/lldb/include/lldb/Host/posix/PipePosix.h @@ -8,6 +8,7 @@ #ifndef LLDB_HOST_POSIX_PIPEPOSIX_H #define LLDB_HOST_POSIX_PIPEPOSIX_H + #include "lldb/Host/PipeBase.h" #include @@ -38,9 +39,8 @@ class PipePosix : public PipeBase { llvm::SmallVectorImpl &name) override; Status OpenAsReader(llvm::StringRef name, bool child_process_inherit) override; - Status - OpenAsWriterWithTimeout(llvm::StringRef name, bool child_process_inherit, - const std::chrono::microseconds &timeout) override; + llvm::Error OpenAsWriter(llvm::StringRef name, bool child_process_inherit, + const Timeout &timeout) override; bool CanRead() const override; bool CanWrite() const override; @@ -64,12 +64,13 @@ class PipePosix : public PipeBase { Status Delete(llvm::StringRef name) override; - Status WriteWithTimeout(const void *buf, size_t size, - const std::chrono::microseconds &timeout, - size_t &bytes_written) override; - Status ReadWithTimeout(void *buf, size_t size, - const std::chrono::microseconds &timeout, - size_t &bytes_read) override; + llvm::Expected + Write(const void *buf, size_t size, + const Timeout &timeout = std::nullopt) override; + + llvm::Expected + Read(void *buf, size_t size, + const Timeout &timeout = std::nullopt) override; private: bool CanReadUnlocked() const; diff --git a/lldb/include/lldb/Host/windows/PipeWindows.h b/lldb/include/lldb/Host/windows/PipeWindows.h index e28d104cc60ec..9cf591a2d4629 100644 --- a/lldb/include/lldb/Host/windows/PipeWindows.h +++ b/lldb/include/lldb/Host/windows/PipeWindows.h @@ -38,9 +38,8 @@ class PipeWindows : public PipeBase { llvm::SmallVectorImpl &name) override; Status OpenAsReader(llvm::StringRef name, bool child_process_inherit) override; - Status - OpenAsWriterWithTimeout(llvm::StringRef name, bool child_process_inherit, - const std::chrono::microseconds &timeout) override; + llvm::Error OpenAsWriter(llvm::StringRef name, bool child_process_inherit, + const Timeout &timeout) override; bool CanRead() const override; bool CanWrite() const override; @@ -59,12 +58,13 @@ class PipeWindows : public PipeBase { Status Delete(llvm::StringRef name) override; - Status WriteWithTimeout(const void *buf, size_t size, - const std::chrono::microseconds &timeout, - size_t &bytes_written) override; - Status ReadWithTimeout(void *buf, size_t size, - const std::chrono::microseconds &timeout, - size_t &bytes_read) override; + llvm::Expected + Write(const void *buf, size_t size, + const Timeout &timeout = std::nullopt) override; + + llvm::Expected + Read(void *buf, size_t size, + const Timeout &timeout = std::nullopt) override; // PipeWindows specific methods. These allow access to the underlying OS // handle. diff --git a/lldb/include/lldb/Interpreter/CommandOptionArgumentTable.h b/lldb/include/lldb/Interpreter/CommandOptionArgumentTable.h index 323f519ede053..8fb3e9e95c83d 100644 --- a/lldb/include/lldb/Interpreter/CommandOptionArgumentTable.h +++ b/lldb/include/lldb/Interpreter/CommandOptionArgumentTable.h @@ -337,6 +337,7 @@ static constexpr CommandObject::ArgumentTableEntry g_argument_table[] = { { lldb::eArgTypeModule, "module", lldb::CompletionType::eModuleCompletion, {}, { nullptr, false }, "The name of a module loaded into the current target." }, { lldb::eArgTypeCPUName, "cpu-name", lldb::CompletionType::eNoCompletion, {}, { nullptr, false }, "The name of a CPU." }, { lldb::eArgTypeCPUFeatures, "cpu-features", lldb::CompletionType::eNoCompletion, {}, { nullptr, false }, "The CPU feature string." }, + { lldb::eArgTypeProtocol, "protocol", lldb::CompletionType::eNoCompletion, {}, { nullptr, false }, "The name of the protocol." }, // clang-format on }; diff --git a/lldb/include/lldb/Target/Target.h b/lldb/include/lldb/Target/Target.h index 50ebcc5a77946..79df47fec620e 100644 --- a/lldb/include/lldb/Target/Target.h +++ b/lldb/include/lldb/Target/Target.h @@ -1176,7 +1176,7 @@ class Target : public std::enable_shared_from_this, Architecture *GetArchitecturePlugin() const { return m_arch.GetPlugin(); } - Debugger &GetDebugger() { return m_debugger; } + Debugger &GetDebugger() const { return m_debugger; } size_t ReadMemoryFromFileCache(const Address &addr, void *dst, size_t dst_len, Status &error); diff --git a/lldb/include/lldb/lldb-enumerations.h b/lldb/include/lldb/lldb-enumerations.h index 882640eccc3d2..42ec593e2ee42 100644 --- a/lldb/include/lldb/lldb-enumerations.h +++ b/lldb/include/lldb/lldb-enumerations.h @@ -669,6 +669,7 @@ enum CommandArgumentType { eArgTypeModule, eArgTypeCPUName, eArgTypeCPUFeatures, + eArgTypeProtocol, eArgTypeLastArg // Always keep this entry as the last entry in this // enumeration!! }; diff --git a/lldb/include/lldb/lldb-forward.h b/lldb/include/lldb/lldb-forward.h index a3550f3fe60ff..d0e7d5e8e2120 100644 --- a/lldb/include/lldb/lldb-forward.h +++ b/lldb/include/lldb/lldb-forward.h @@ -164,13 +164,13 @@ class PersistentExpressionState; class Platform; class Process; class ProcessAttachInfo; -class ProcessLaunchInfo; class ProcessInfo; class ProcessInstanceInfo; class ProcessInstanceInfoMatch; class ProcessLaunchInfo; class ProcessModID; class Property; +class ProtocolServer; class Queue; class QueueImpl; class QueueItem; @@ -389,6 +389,7 @@ typedef std::shared_ptr PlatformSP; typedef std::shared_ptr ProcessSP; typedef std::shared_ptr ProcessAttachInfoSP; typedef std::shared_ptr ProcessLaunchInfoSP; +typedef std::unique_ptr ProtocolServerUP; typedef std::weak_ptr ProcessWP; typedef std::shared_ptr RegisterCheckpointSP; typedef std::shared_ptr RegisterContextSP; diff --git a/lldb/include/lldb/lldb-private-interfaces.h b/lldb/include/lldb/lldb-private-interfaces.h index cd5ccc44324c3..6511269f32f37 100644 --- a/lldb/include/lldb/lldb-private-interfaces.h +++ b/lldb/include/lldb/lldb-private-interfaces.h @@ -82,6 +82,7 @@ typedef lldb::PlatformSP (*PlatformCreateInstance)(bool force, typedef lldb::ProcessSP (*ProcessCreateInstance)( lldb::TargetSP target_sp, lldb::ListenerSP listener_sp, const FileSpec *crash_file_path, bool can_connect); +typedef lldb::ProtocolServerUP (*ProtocolServerCreateInstance)(); typedef lldb::RegisterTypeBuilderSP (*RegisterTypeBuilderCreateInstance)( Target &target); typedef lldb::ScriptInterpreterSP (*ScriptInterpreterCreateInstance)( diff --git a/lldb/source/Commands/CMakeLists.txt b/lldb/source/Commands/CMakeLists.txt index 186d778305a4e..fab0e303d8b10 100644 --- a/lldb/source/Commands/CMakeLists.txt +++ b/lldb/source/Commands/CMakeLists.txt @@ -28,6 +28,7 @@ add_lldb_library(lldbCommands NO_PLUGIN_DEPENDENCIES CommandObjectPlatform.cpp CommandObjectPlugin.cpp CommandObjectProcess.cpp + CommandObjectProtocolServer.cpp CommandObjectQuit.cpp CommandObjectRegexCommand.cpp CommandObjectRegister.cpp diff --git a/lldb/source/Commands/CommandObjectProtocolServer.cpp b/lldb/source/Commands/CommandObjectProtocolServer.cpp new file mode 100644 index 0000000000000..38d93cabf8c04 --- /dev/null +++ b/lldb/source/Commands/CommandObjectProtocolServer.cpp @@ -0,0 +1,143 @@ +//===-- CommandObjectProtocolServer.cpp +//----------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "CommandObjectProtocolServer.h" +#include "lldb/Core/PluginManager.h" +#include "lldb/Core/ProtocolServer.h" +#include "lldb/Host/Socket.h" +#include "lldb/Interpreter/CommandInterpreter.h" +#include "lldb/Interpreter/CommandReturnObject.h" +#include "lldb/Utility/UriParser.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/FormatAdapters.h" + +using namespace llvm; +using namespace lldb; +using namespace lldb_private; + +#define LLDB_OPTIONS_mcp +#include "CommandOptions.inc" + +class CommandObjectProtocolServerStart : public CommandObjectParsed { +public: + CommandObjectProtocolServerStart(CommandInterpreter &interpreter) + : CommandObjectParsed(interpreter, "protocol-server start", + "start protocol server", + "protocol-server start ") { + AddSimpleArgumentList(lldb::eArgTypeProtocol, eArgRepeatPlain); + AddSimpleArgumentList(lldb::eArgTypeConnectURL, eArgRepeatPlain); + } + + ~CommandObjectProtocolServerStart() override = default; + +protected: + void DoExecute(Args &args, CommandReturnObject &result) override { + if (args.GetArgumentCount() < 1) { + result.AppendError("no protocol specified"); + return; + } + + llvm::StringRef protocol = args.GetArgumentAtIndex(0); + ProtocolServer *server = ProtocolServer::GetOrCreate(protocol); + if (!server) { + result.AppendErrorWithFormatv( + "unsupported protocol: {0}. Supported protocols are: {1}", protocol, + llvm::join(ProtocolServer::GetSupportedProtocols(), ", ")); + return; + } + + if (args.GetArgumentCount() < 2) { + result.AppendError("no connection specified"); + return; + } + llvm::StringRef connection_uri = args.GetArgumentAtIndex(1); + + const char *connection_error = + "unsupported connection specifier, expected 'accept:///path' or " + "'listen://[host]:port', got '{0}'."; + auto uri = lldb_private::URI::Parse(connection_uri); + if (!uri) { + result.AppendErrorWithFormatv(connection_error, connection_uri); + return; + } + + std::optional protocol_and_mode = + Socket::GetProtocolAndMode(uri->scheme); + if (!protocol_and_mode || protocol_and_mode->second != Socket::ModeAccept) { + result.AppendErrorWithFormatv(connection_error, connection_uri); + return; + } + + ProtocolServer::Connection connection; + connection.protocol = protocol_and_mode->first; + connection.name = + formatv("[{0}]:{1}", uri->hostname.empty() ? "0.0.0.0" : uri->hostname, + uri->port.value_or(0)); + + if (llvm::Error error = server->Start(connection)) { + result.AppendErrorWithFormatv("{0}", llvm::fmt_consume(std::move(error))); + return; + } + + if (Socket *socket = server->GetSocket()) { + std::string address = + llvm::join(socket->GetListeningConnectionURI(), ", "); + result.AppendMessageWithFormatv( + "{0} server started with connection listeners: {1}", protocol, + address); + } + } +}; + +class CommandObjectProtocolServerStop : public CommandObjectParsed { +public: + CommandObjectProtocolServerStop(CommandInterpreter &interpreter) + : CommandObjectParsed(interpreter, "protocol-server stop", + "stop protocol server", + "protocol-server stop ") { + AddSimpleArgumentList(lldb::eArgTypeProtocol, eArgRepeatPlain); + } + + ~CommandObjectProtocolServerStop() override = default; + +protected: + void DoExecute(Args &args, CommandReturnObject &result) override { + if (args.GetArgumentCount() < 1) { + result.AppendError("no protocol specified"); + return; + } + + llvm::StringRef protocol = args.GetArgumentAtIndex(0); + ProtocolServer *server = ProtocolServer::GetOrCreate(protocol); + if (!server) { + result.AppendErrorWithFormatv( + "unsupported protocol: {0}. Supported protocols are: {1}", protocol, + llvm::join(ProtocolServer::GetSupportedProtocols(), ", ")); + return; + } + + if (llvm::Error error = server->Stop()) { + result.AppendErrorWithFormatv("{0}", llvm::fmt_consume(std::move(error))); + return; + } + } +}; + +CommandObjectProtocolServer::CommandObjectProtocolServer( + CommandInterpreter &interpreter) + : CommandObjectMultiword(interpreter, "protocol-server", + "Start and stop a protocol server.", + "protocol-server") { + LoadSubCommand("start", CommandObjectSP(new CommandObjectProtocolServerStart( + interpreter))); + LoadSubCommand("stop", CommandObjectSP( + new CommandObjectProtocolServerStop(interpreter))); +} + +CommandObjectProtocolServer::~CommandObjectProtocolServer() = default; diff --git a/lldb/source/Commands/CommandObjectProtocolServer.h b/lldb/source/Commands/CommandObjectProtocolServer.h new file mode 100644 index 0000000000000..3591216b014cb --- /dev/null +++ b/lldb/source/Commands/CommandObjectProtocolServer.h @@ -0,0 +1,25 @@ +//===-- CommandObjectProtocolServer.h +//------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLDB_SOURCE_COMMANDS_COMMANDOBJECTPROTOCOLSERVER_H +#define LLDB_SOURCE_COMMANDS_COMMANDOBJECTPROTOCOLSERVER_H + +#include "lldb/Interpreter/CommandObjectMultiword.h" + +namespace lldb_private { + +class CommandObjectProtocolServer : public CommandObjectMultiword { +public: + CommandObjectProtocolServer(CommandInterpreter &interpreter); + ~CommandObjectProtocolServer() override; +}; + +} // namespace lldb_private + +#endif // LLDB_SOURCE_COMMANDS_COMMANDOBJECTMCP_H diff --git a/lldb/source/Core/CMakeLists.txt b/lldb/source/Core/CMakeLists.txt index c6bb3cded801a..e15bff774e02f 100644 --- a/lldb/source/Core/CMakeLists.txt +++ b/lldb/source/Core/CMakeLists.txt @@ -48,6 +48,7 @@ add_lldb_library(lldbCore Opcode.cpp PluginManager.cpp Progress.cpp + ProtocolServer.cpp Statusline.cpp RichManglingContext.cpp SearchFilter.cpp diff --git a/lldb/source/Core/Debugger.cpp b/lldb/source/Core/Debugger.cpp index 0efc9d9a4482f..bcafdb083ef3e 100644 --- a/lldb/source/Core/Debugger.cpp +++ b/lldb/source/Core/Debugger.cpp @@ -16,6 +16,7 @@ #include "lldb/Core/ModuleSpec.h" #include "lldb/Core/PluginManager.h" #include "lldb/Core/Progress.h" +#include "lldb/Core/ProtocolServer.h" #include "lldb/Core/StreamAsynchronousIO.h" #include "lldb/DataFormatters/DataVisualization.h" #include "lldb/Expression/REPL.h" diff --git a/lldb/source/Core/PluginManager.cpp b/lldb/source/Core/PluginManager.cpp index 8a19684d63f28..ed93f7dee6597 100644 --- a/lldb/source/Core/PluginManager.cpp +++ b/lldb/source/Core/PluginManager.cpp @@ -905,6 +905,38 @@ void PluginManager::AutoCompleteProcessName(llvm::StringRef name, } } +#pragma mark ProtocolServer + +typedef PluginInstance ProtocolServerInstance; +typedef PluginInstances ProtocolServerInstances; + +static ProtocolServerInstances &GetProtocolServerInstances() { + static ProtocolServerInstances g_instances; + return g_instances; +} + +bool PluginManager::RegisterPlugin( + llvm::StringRef name, llvm::StringRef description, + ProtocolServerCreateInstance create_callback) { + return GetProtocolServerInstances().RegisterPlugin(name, description, + create_callback); +} + +bool PluginManager::UnregisterPlugin( + ProtocolServerCreateInstance create_callback) { + return GetProtocolServerInstances().UnregisterPlugin(create_callback); +} + +llvm::StringRef +PluginManager::GetProtocolServerPluginNameAtIndex(uint32_t idx) { + return GetProtocolServerInstances().GetNameAtIndex(idx); +} + +ProtocolServerCreateInstance +PluginManager::GetProtocolCreateCallbackForPluginName(llvm::StringRef name) { + return GetProtocolServerInstances().GetCallbackForName(name); +} + #pragma mark RegisterTypeBuilder struct RegisterTypeBuilderInstance diff --git a/lldb/source/Core/ProtocolServer.cpp b/lldb/source/Core/ProtocolServer.cpp new file mode 100644 index 0000000000000..41636cdacdecc --- /dev/null +++ b/lldb/source/Core/ProtocolServer.cpp @@ -0,0 +1,47 @@ +//===-- ProtocolServer.cpp ------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "lldb/Core/ProtocolServer.h" +#include "lldb/Core/PluginManager.h" + +using namespace lldb_private; +using namespace lldb; + +ProtocolServer *ProtocolServer::GetOrCreate(llvm::StringRef name) { + static std::mutex g_mutex; + static llvm::StringMap g_protocol_server_instances; + + std::lock_guard guard(g_mutex); + + auto it = g_protocol_server_instances.find(name); + if (it != g_protocol_server_instances.end()) + return it->second.get(); + + if (ProtocolServerCreateInstance create_callback = + PluginManager::GetProtocolCreateCallbackForPluginName(name)) { + auto pair = + g_protocol_server_instances.try_emplace(name, create_callback()); + return pair.first->second.get(); + } + + return nullptr; +} + +std::vector ProtocolServer::GetSupportedProtocols() { + std::vector supported_protocols; + size_t i = 0; + + for (llvm::StringRef protocol_name = + PluginManager::GetProtocolServerPluginNameAtIndex(i++); + !protocol_name.empty(); + protocol_name = PluginManager::GetProtocolServerPluginNameAtIndex(i++)) { + supported_protocols.push_back(protocol_name); + } + + return supported_protocols; +} diff --git a/lldb/source/Host/CMakeLists.txt b/lldb/source/Host/CMakeLists.txt index 8b96bb1451fce..e60e0860a90ca 100644 --- a/lldb/source/Host/CMakeLists.txt +++ b/lldb/source/Host/CMakeLists.txt @@ -24,8 +24,9 @@ add_host_subdirectory(common common/HostNativeThreadBase.cpp common/HostProcess.cpp common/HostThread.cpp - common/LockFileBase.cpp + common/JSONTransport.cpp common/LZMA.cpp + common/LockFileBase.cpp common/MainLoopBase.cpp common/MonitoringProcessLauncher.cpp common/NativeProcessProtocol.cpp diff --git a/lldb/source/Host/common/JSONTransport.cpp b/lldb/source/Host/common/JSONTransport.cpp new file mode 100644 index 0000000000000..1a0851d5c4365 --- /dev/null +++ b/lldb/source/Host/common/JSONTransport.cpp @@ -0,0 +1,176 @@ +//===-- JSONTransport.cpp -------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "lldb/Host/JSONTransport.h" +#include "lldb/Utility/IOObject.h" +#include "lldb/Utility/LLDBLog.h" +#include "lldb/Utility/Log.h" +#include "lldb/Utility/SelectHelper.h" +#include "lldb/Utility/Status.h" +#include "lldb/lldb-forward.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Error.h" +#include "llvm/Support/raw_ostream.h" +#include +#include +#include + +using namespace llvm; +using namespace lldb; +using namespace lldb_private; + +/// ReadFull attempts to read the specified number of bytes. If EOF is +/// encountered, an empty string is returned. +static Expected +ReadFull(IOObject &descriptor, size_t length, + std::optional timeout = std::nullopt) { + if (!descriptor.IsValid()) + return llvm::make_error(); + + bool timeout_supported = true; + // FIXME: SelectHelper does not work with NativeFile on Win32. +#if _WIN32 + timeout_supported = descriptor.GetFdType() == IOObject::eFDTypeSocket; +#endif + + if (timeout && timeout_supported) { + SelectHelper sh; + sh.SetTimeout(*timeout); + sh.FDSetRead(descriptor.GetWaitableHandle()); + Status status = sh.Select(); + if (status.Fail()) { + // Convert timeouts into a specific error. + if (status.GetType() == lldb::eErrorTypePOSIX && + status.GetError() == ETIMEDOUT) + return make_error(); + return status.takeError(); + } + } + + std::string data; + data.resize(length); + Status status = descriptor.Read(data.data(), length); + if (status.Fail()) + return status.takeError(); + + // Read returns '' on EOF. + if (length == 0) + return make_error(); + + // Return the actual number of bytes read. + return data.substr(0, length); +} + +static Expected +ReadUntil(IOObject &descriptor, StringRef delimiter, + std::optional timeout = std::nullopt) { + std::string buffer; + buffer.reserve(delimiter.size() + 1); + while (!llvm::StringRef(buffer).ends_with(delimiter)) { + Expected next = + ReadFull(descriptor, buffer.empty() ? delimiter.size() : 1, timeout); + if (auto Err = next.takeError()) + return std::move(Err); + buffer += *next; + } + return buffer.substr(0, buffer.size() - delimiter.size()); +} + +JSONTransport::JSONTransport(IOObjectSP input, IOObjectSP output) + : m_input(std::move(input)), m_output(std::move(output)) {} + +void JSONTransport::Log(llvm::StringRef message) { + LLDB_LOG(GetLog(LLDBLog::Host), "{0}", message); +} + +Expected +HTTPDelimitedJSONTransport::ReadImpl(const std::chrono::microseconds &timeout) { + if (!m_input || !m_input->IsValid()) + return llvm::make_error(); + + IOObject *input = m_input.get(); + Expected message_header = + ReadFull(*input, kHeaderContentLength.size(), timeout); + if (!message_header) + return message_header.takeError(); + if (*message_header != kHeaderContentLength) + return createStringError(formatv("expected '{0}' and got '{1}'", + kHeaderContentLength, *message_header) + .str()); + + Expected raw_length = ReadUntil(*input, kHeaderSeparator); + if (!raw_length) + return handleErrors(raw_length.takeError(), + [&](const TransportEOFError &E) -> llvm::Error { + return createStringError( + "unexpected EOF while reading header separator"); + }); + + size_t length; + if (!to_integer(*raw_length, length)) + return createStringError( + formatv("invalid content length {0}", *raw_length).str()); + + Expected raw_json = ReadFull(*input, length); + if (!raw_json) + return handleErrors( + raw_json.takeError(), [&](const TransportEOFError &E) -> llvm::Error { + return createStringError("unexpected EOF while reading JSON"); + }); + + Log(llvm::formatv("--> {0}", *raw_json).str()); + + return raw_json; +} + +Error HTTPDelimitedJSONTransport::WriteImpl(const std::string &message) { + if (!m_output || !m_output->IsValid()) + return llvm::make_error(); + + Log(llvm::formatv("<-- {0}", message).str()); + + std::string Output; + raw_string_ostream OS(Output); + OS << kHeaderContentLength << message.length() << kHeaderSeparator << message; + size_t num_bytes = Output.size(); + return m_output->Write(Output.data(), num_bytes).takeError(); +} + +Expected +JSONRPCTransport::ReadImpl(const std::chrono::microseconds &timeout) { + if (!m_input || !m_input->IsValid()) + return make_error(); + + IOObject *input = m_input.get(); + Expected raw_json = + ReadUntil(*input, kMessageSeparator, timeout); + if (!raw_json) + return raw_json.takeError(); + + Log(llvm::formatv("--> {0}", *raw_json).str()); + + return *raw_json; +} + +Error JSONRPCTransport::WriteImpl(const std::string &message) { + if (!m_output || !m_output->IsValid()) + return llvm::make_error(); + + Log(llvm::formatv("<-- {0}", message).str()); + + std::string Output; + llvm::raw_string_ostream OS(Output); + OS << message << kMessageSeparator; + size_t num_bytes = Output.size(); + return m_output->Write(Output.data(), num_bytes).takeError(); +} + +char TransportEOFError::ID; +char TransportTimeoutError::ID; +char TransportInvalidError::ID; diff --git a/lldb/source/Host/common/PipeBase.cpp b/lldb/source/Host/common/PipeBase.cpp index 904a2df12392d..400990f4e41b9 100644 --- a/lldb/source/Host/common/PipeBase.cpp +++ b/lldb/source/Host/common/PipeBase.cpp @@ -11,19 +11,3 @@ using namespace lldb_private; PipeBase::~PipeBase() = default; - -Status PipeBase::OpenAsWriter(llvm::StringRef name, - bool child_process_inherit) { - return OpenAsWriterWithTimeout(name, child_process_inherit, - std::chrono::microseconds::zero()); -} - -Status PipeBase::Write(const void *buf, size_t size, size_t &bytes_written) { - return WriteWithTimeout(buf, size, std::chrono::microseconds::zero(), - bytes_written); -} - -Status PipeBase::Read(void *buf, size_t size, size_t &bytes_read) { - return ReadWithTimeout(buf, size, std::chrono::microseconds::zero(), - bytes_read); -} diff --git a/lldb/source/Host/common/Socket.cpp b/lldb/source/Host/common/Socket.cpp index 1a506aa95b246..77b80da2cb5ea 100644 --- a/lldb/source/Host/common/Socket.cpp +++ b/lldb/source/Host/common/Socket.cpp @@ -10,6 +10,7 @@ #include "lldb/Host/Config.h" #include "lldb/Host/Host.h" +#include "lldb/Host/MainLoop.h" #include "lldb/Host/SocketAddress.h" #include "lldb/Host/common/TCPSocket.h" #include "lldb/Host/common/UDPSocket.h" @@ -102,15 +103,14 @@ Status SharedSocket::CompleteSending(lldb::pid_t child_pid) { "WSADuplicateSocket() failed, error: %d", last_error); } - size_t num_bytes; - Status error = - m_socket_pipe.WriteWithTimeout(&protocol_info, sizeof(protocol_info), - std::chrono::seconds(10), num_bytes); - if (error.Fail()) - return error; - if (num_bytes != sizeof(protocol_info)) + llvm::Expected num_bytes = m_socket_pipe.Write( + &protocol_info, sizeof(protocol_info), std::chrono::seconds(10)); + if (!num_bytes) + return Status::FromError(num_bytes.takeError()); + if (*num_bytes != sizeof(protocol_info)) return Status::FromErrorStringWithFormatv( - "WriteWithTimeout(WSAPROTOCOL_INFO) failed: {0} bytes", num_bytes); + "Write(WSAPROTOCOL_INFO) failed: wrote {0}/{1} bytes", *num_bytes, + sizeof(protocol_info)); #endif return Status(); } @@ -122,16 +122,14 @@ Status SharedSocket::GetNativeSocket(shared_fd_t fd, NativeSocket &socket) { WSAPROTOCOL_INFO protocol_info; { Pipe socket_pipe(fd, LLDB_INVALID_PIPE); - size_t num_bytes; - Status error = - socket_pipe.ReadWithTimeout(&protocol_info, sizeof(protocol_info), - std::chrono::seconds(10), num_bytes); - if (error.Fail()) - return error; - if (num_bytes != sizeof(protocol_info)) { + llvm::Expected num_bytes = socket_pipe.Read( + &protocol_info, sizeof(protocol_info), std::chrono::seconds(10)); + if (!num_bytes) + return Status::FromError(num_bytes.takeError()); + if (*num_bytes != sizeof(protocol_info)) { return Status::FromErrorStringWithFormatv( - "socket_pipe.ReadWithTimeout(WSAPROTOCOL_INFO) failed: {0} bytes", - num_bytes); + "Read(WSAPROTOCOL_INFO) failed: read {0}/{1} bytes", *num_bytes, + sizeof(protocol_info)); } } socket = ::WSASocket(FROM_PROTOCOL_INFO, FROM_PROTOCOL_INFO, @@ -294,7 +292,8 @@ Socket::UdpConnect(llvm::StringRef host_and_port, return UDPSocket::Connect(host_and_port, child_processes_inherit); } -llvm::Expected Socket::DecodeHostAndPort(llvm::StringRef host_and_port) { +llvm::Expected +Socket::DecodeHostAndPort(llvm::StringRef host_and_port) { static llvm::Regex g_regex("([^:]+|\\[[0-9a-fA-F:]+.*\\]):([0-9]+)"); HostAndPort ret; llvm::SmallVector matches; @@ -370,8 +369,8 @@ Status Socket::Write(const void *buf, size_t &num_bytes) { ", src = %p, src_len = %" PRIu64 ", flags = 0) => %" PRIi64 " (error = %s)", static_cast(this), static_cast(m_socket), buf, - static_cast(src_len), - static_cast(bytes_sent), error.AsCString()); + static_cast(src_len), static_cast(bytes_sent), + error.AsCString()); } return error; @@ -443,6 +442,19 @@ NativeSocket Socket::CreateSocket(const int domain, const int type, return sock; } +Status Socket::Accept(Socket *&socket) { + MainLoop accept_loop; + llvm::Expected> expected_handles = + Accept(accept_loop, + [&accept_loop, &socket](std::unique_ptr sock) { + socket = sock.release(); + accept_loop.RequestTermination(); + }); + if (!expected_handles) + return Status::FromError(expected_handles.takeError()); + return accept_loop.Run(); +} + NativeSocket Socket::AcceptSocket(NativeSocket sockfd, struct sockaddr *addr, socklen_t *addrlen, bool child_processes_inherit, Status &error) { @@ -483,3 +495,28 @@ llvm::raw_ostream &lldb_private::operator<<(llvm::raw_ostream &OS, const Socket::HostAndPort &HP) { return OS << '[' << HP.hostname << ']' << ':' << HP.port; } + +std::optional +Socket::GetProtocolAndMode(llvm::StringRef scheme) { + // Keep in sync with ConnectionFileDescriptor::Connect. + return llvm::StringSwitch>(scheme) + .Case("listen", ProtocolModePair{SocketProtocol::ProtocolTcp, + SocketMode::ModeAccept}) + .Cases("accept", "unix-accept", + ProtocolModePair{SocketProtocol::ProtocolUnixDomain, + SocketMode::ModeAccept}) + .Case("unix-abstract-accept", + ProtocolModePair{SocketProtocol::ProtocolUnixAbstract, + SocketMode::ModeAccept}) + .Cases("connect", "tcp-connect", + ProtocolModePair{SocketProtocol::ProtocolTcp, + SocketMode::ModeConnect}) + .Case("udp", ProtocolModePair{SocketProtocol::ProtocolTcp, + SocketMode::ModeConnect}) + .Case("unix-connect", ProtocolModePair{SocketProtocol::ProtocolUnixDomain, + SocketMode::ModeConnect}) + .Case("unix-abstract-connect", + ProtocolModePair{SocketProtocol::ProtocolUnixAbstract, + SocketMode::ModeConnect}) + .Default(std::nullopt); +} diff --git a/lldb/source/Host/common/TCPSocket.cpp b/lldb/source/Host/common/TCPSocket.cpp index 1f31190b02f97..fef71810df202 100644 --- a/lldb/source/Host/common/TCPSocket.cpp +++ b/lldb/source/Host/common/TCPSocket.cpp @@ -137,6 +137,14 @@ std::string TCPSocket::GetRemoteConnectionURI() const { return ""; } +std::vector TCPSocket::GetListeningConnectionURI() const { + std::vector URIs; + for (const auto &[fd, addr] : m_listen_sockets) + URIs.emplace_back(llvm::formatv("connection://[{0}]:{1}", + addr.GetIPAddress(), addr.GetPort())); + return URIs; +} + Status TCPSocket::CreateSocket(int domain) { Status error; if (IsValid()) @@ -255,9 +263,9 @@ void TCPSocket::CloseListenSockets() { m_listen_sockets.clear(); } -llvm::Expected> TCPSocket::Accept( - MainLoopBase &loop, - std::function socket)> sock_cb) { +llvm::Expected> +TCPSocket::Accept(MainLoopBase &loop, + std::function socket)> sock_cb) { if (m_listen_sockets.size() == 0) return llvm::createStringError("No open listening sockets!"); @@ -301,19 +309,6 @@ llvm::Expected> TCPSocket::Accept( return handles; } -Status TCPSocket::Accept(Socket *&conn_socket) { - MainLoop accept_loop; - llvm::Expected> expected_handles = - Accept(accept_loop, - [&accept_loop, &conn_socket](std::unique_ptr sock) { - conn_socket = sock.release(); - accept_loop.RequestTermination(); - }); - if (!expected_handles) - return Status::FromError(expected_handles.takeError()); - return accept_loop.Run(); -} - int TCPSocket::SetOptionNoDelay() { return SetOption(IPPROTO_TCP, TCP_NODELAY, 1); } diff --git a/lldb/source/Host/common/UDPSocket.cpp b/lldb/source/Host/common/UDPSocket.cpp index 2a7a6cff414b1..05d7b2e650602 100644 --- a/lldb/source/Host/common/UDPSocket.cpp +++ b/lldb/source/Host/common/UDPSocket.cpp @@ -47,10 +47,6 @@ Status UDPSocket::Listen(llvm::StringRef name, int backlog) { return Status::FromErrorStringWithFormat("%s", g_not_supported_error); } -Status UDPSocket::Accept(Socket *&socket) { - return Status::FromErrorStringWithFormat("%s", g_not_supported_error); -} - llvm::Expected> UDPSocket::Connect(llvm::StringRef name, bool child_processes_inherit) { std::unique_ptr socket; diff --git a/lldb/source/Host/posix/ConnectionFileDescriptorPosix.cpp b/lldb/source/Host/posix/ConnectionFileDescriptorPosix.cpp index d0cc68826d4bb..e0173e90515c5 100644 --- a/lldb/source/Host/posix/ConnectionFileDescriptorPosix.cpp +++ b/lldb/source/Host/posix/ConnectionFileDescriptorPosix.cpp @@ -183,9 +183,7 @@ ConnectionFileDescriptor::Connect(llvm::StringRef path, } bool ConnectionFileDescriptor::InterruptRead() { - size_t bytes_written = 0; - Status result = m_pipe.Write("i", 1, bytes_written); - return result.Success(); + return !errorToBool(m_pipe.Write("i", 1).takeError()); } ConnectionStatus ConnectionFileDescriptor::Disconnect(Status *error_ptr) { @@ -210,13 +208,11 @@ ConnectionStatus ConnectionFileDescriptor::Disconnect(Status *error_ptr) { std::unique_lock locker(m_mutex, std::defer_lock); if (!locker.try_lock()) { if (m_pipe.CanWrite()) { - size_t bytes_written = 0; - Status result = m_pipe.Write("q", 1, bytes_written); - LLDB_LOGF(log, - "%p ConnectionFileDescriptor::Disconnect(): Couldn't get " - "the lock, sent 'q' to %d, error = '%s'.", - static_cast(this), m_pipe.GetWriteFileDescriptor(), - result.AsCString()); + llvm::Error err = m_pipe.Write("q", 1).takeError(); + LLDB_LOG(log, + "{0}: Couldn't get the lock, sent 'q' to {1}, error = '{2}'.", + this, m_pipe.GetWriteFileDescriptor(), err); + consumeError(std::move(err)); } else if (log) { LLDB_LOGF(log, "%p ConnectionFileDescriptor::Disconnect(): Couldn't get the " diff --git a/lldb/source/Host/posix/DomainSocket.cpp b/lldb/source/Host/posix/DomainSocket.cpp index 2d18995c3bb46..6822932274b31 100644 --- a/lldb/source/Host/posix/DomainSocket.cpp +++ b/lldb/source/Host/posix/DomainSocket.cpp @@ -7,11 +7,13 @@ //===----------------------------------------------------------------------===// #include "lldb/Host/posix/DomainSocket.h" +#include "lldb/Utility/LLDBLog.h" #include "llvm/Support/Errno.h" #include "llvm/Support/FileSystem.h" #include +#include #include #include @@ -57,7 +59,14 @@ static bool SetSockAddr(llvm::StringRef name, const size_t name_offset, } DomainSocket::DomainSocket(bool should_close, bool child_processes_inherit) - : Socket(ProtocolUnixDomain, should_close, child_processes_inherit) {} + : DomainSocket(kInvalidSocketValue, should_close, child_processes_inherit) { +} + +DomainSocket::DomainSocket(NativeSocket socket, bool should_close, + bool child_processes_inherit) + : Socket(ProtocolUnixDomain, should_close, child_processes_inherit) { + m_socket = socket; +} DomainSocket::DomainSocket(SocketProtocol protocol, bool child_processes_inherit) @@ -108,14 +117,31 @@ Status DomainSocket::Listen(llvm::StringRef name, int backlog) { return error; } -Status DomainSocket::Accept(Socket *&socket) { - Status error; - auto conn_fd = AcceptSocket(GetNativeSocket(), nullptr, nullptr, - m_child_processes_inherit, error); - if (error.Success()) - socket = new DomainSocket(conn_fd, *this); +llvm::Expected> DomainSocket::Accept( + MainLoopBase &loop, + std::function socket)> sock_cb) { + // TODO: Refactor MainLoop to avoid the shared_ptr requirement. + auto io_sp = std::make_shared(GetNativeSocket(), false, + m_child_processes_inherit); + auto cb = [this, sock_cb](MainLoopBase &loop) { + Log *log = GetLog(LLDBLog::Host); + Status error; + auto conn_fd = AcceptSocket(GetNativeSocket(), nullptr, nullptr, + m_child_processes_inherit, error); + if (error.Fail()) { + LLDB_LOG(log, "AcceptSocket({0}): {1}", GetNativeSocket(), error); + return; + } + std::unique_ptr sock_up(new DomainSocket(conn_fd, *this)); + sock_cb(std::move(sock_up)); + }; - return error; + Status error; + std::vector handles; + handles.emplace_back(loop.RegisterReadObject(io_sp, cb, error)); + if (error.Fail()) + return error.ToError(); + return handles; } size_t DomainSocket::GetNameOffset() const { return 0; } @@ -155,3 +181,17 @@ std::string DomainSocket::GetRemoteConnectionURI() const { "{0}://{1}", GetNameOffset() == 0 ? "unix-connect" : "unix-abstract-connect", name); } + +std::vector DomainSocket::GetListeningConnectionURI() const { + if (m_socket == kInvalidSocketValue) + return {}; + + struct sockaddr_un addr; + bzero(&addr, sizeof(struct sockaddr_un)); + addr.sun_family = AF_UNIX; + socklen_t addr_len = sizeof(struct sockaddr_un); + if (::getsockname(m_socket, (struct sockaddr *)&addr, &addr_len) != 0) + return {}; + + return {llvm::formatv("unix-connect://{0}", addr.sun_path)}; +} diff --git a/lldb/source/Host/posix/MainLoopPosix.cpp b/lldb/source/Host/posix/MainLoopPosix.cpp index 816581e70294a..3106f6e7c0e11 100644 --- a/lldb/source/Host/posix/MainLoopPosix.cpp +++ b/lldb/source/Host/posix/MainLoopPosix.cpp @@ -404,9 +404,5 @@ void MainLoopPosix::TriggerPendingCallbacks() { return; char c = '.'; - size_t bytes_written; - Status error = m_trigger_pipe.Write(&c, 1, bytes_written); - assert(error.Success()); - UNUSED_IF_ASSERT_DISABLED(error); - assert(bytes_written == 1); + cantFail(m_trigger_pipe.Write(&c, 1)); } diff --git a/lldb/source/Host/posix/PipePosix.cpp b/lldb/source/Host/posix/PipePosix.cpp index 24c563d8c24bd..a8c4f8df333a4 100644 --- a/lldb/source/Host/posix/PipePosix.cpp +++ b/lldb/source/Host/posix/PipePosix.cpp @@ -12,7 +12,9 @@ #include "lldb/Utility/SelectHelper.h" #include "llvm/ADT/SmallString.h" #include "llvm/Support/Errno.h" +#include "llvm/Support/Error.h" #include +#include #include #include @@ -164,26 +166,27 @@ Status PipePosix::OpenAsReader(llvm::StringRef name, return error; } -Status -PipePosix::OpenAsWriterWithTimeout(llvm::StringRef name, - bool child_process_inherit, - const std::chrono::microseconds &timeout) { +llvm::Error PipePosix::OpenAsWriter(llvm::StringRef name, + bool child_process_inherit, + const Timeout &timeout) { std::lock_guard guard(m_write_mutex); if (CanReadUnlocked() || CanWriteUnlocked()) - return Status::FromErrorString("Pipe is already opened"); + return llvm::createStringError("Pipe is already opened"); int flags = O_WRONLY | O_NONBLOCK; if (!child_process_inherit) flags |= O_CLOEXEC; using namespace std::chrono; - const auto finish_time = Now() + timeout; + std::optional> finish_time; + if (timeout) + finish_time = Now() + *timeout; while (!CanWriteUnlocked()) { - if (timeout != microseconds::zero()) { - const auto dur = duration_cast(finish_time - Now()).count(); - if (dur <= 0) - return Status::FromErrorString( + if (timeout) { + if (Now() > finish_time) + return llvm::createStringError( + std::make_error_code(std::errc::timed_out), "timeout exceeded - reader hasn't opened so far"); } @@ -193,7 +196,8 @@ PipePosix::OpenAsWriterWithTimeout(llvm::StringRef name, const auto errno_copy = errno; // We may get ENXIO if a reader side of the pipe hasn't opened yet. if (errno_copy != ENXIO && errno_copy != EINTR) - return Status(errno_copy, eErrorTypePOSIX); + return llvm::errorCodeToError( + std::error_code(errno_copy, std::generic_category())); std::this_thread::sleep_for( milliseconds(OPEN_WRITER_SLEEP_TIMEOUT_MSECS)); @@ -202,7 +206,7 @@ PipePosix::OpenAsWriterWithTimeout(llvm::StringRef name, } } - return Status(); + return llvm::Error::success(); } int PipePosix::GetReadFileDescriptor() const { @@ -300,70 +304,51 @@ void PipePosix::CloseWriteFileDescriptorUnlocked() { } } -Status PipePosix::ReadWithTimeout(void *buf, size_t size, - const std::chrono::microseconds &timeout, - size_t &bytes_read) { +llvm::Expected PipePosix::Read(void *buf, size_t size, + const Timeout &timeout) { std::lock_guard guard(m_read_mutex); - bytes_read = 0; if (!CanReadUnlocked()) - return Status(EINVAL, eErrorTypePOSIX); + return llvm::errorCodeToError( + std::make_error_code(std::errc::invalid_argument)); const int fd = GetReadFileDescriptorUnlocked(); SelectHelper select_helper; - select_helper.SetTimeout(timeout); + if (timeout) + select_helper.SetTimeout(*timeout); select_helper.FDSetRead(fd); - Status error; - while (error.Success()) { - error = select_helper.Select(); - if (error.Success()) { - auto result = - ::read(fd, static_cast(buf) + bytes_read, size - bytes_read); - if (result != -1) { - bytes_read += result; - if (bytes_read == size || result == 0) - break; - } else if (errno == EINTR) { - continue; - } else { - error = Status::FromErrno(); - break; - } - } - } - return error; + if (llvm::Error error = select_helper.Select().takeError()) + return error; + + ssize_t result = ::read(fd, buf, size); + if (result == -1) + return llvm::errorCodeToError( + std::error_code(errno, std::generic_category())); + + return result; } -Status PipePosix::WriteWithTimeout(const void *buf, size_t size, - const std::chrono::microseconds &timeout, - size_t &bytes_written) { +llvm::Expected PipePosix::Write(const void *buf, size_t size, + const Timeout &timeout) { std::lock_guard guard(m_write_mutex); - bytes_written = 0; if (!CanWriteUnlocked()) - return Status(EINVAL, eErrorTypePOSIX); + return llvm::errorCodeToError( + std::make_error_code(std::errc::invalid_argument)); const int fd = GetWriteFileDescriptorUnlocked(); SelectHelper select_helper; - select_helper.SetTimeout(timeout); + if (timeout) + select_helper.SetTimeout(*timeout); select_helper.FDSetWrite(fd); - Status error; - while (error.Success()) { - error = select_helper.Select(); - if (error.Success()) { - auto result = ::write(fd, static_cast(buf) + bytes_written, - size - bytes_written); - if (result != -1) { - bytes_written += result; - if (bytes_written == size) - break; - } else if (errno == EINTR) { - continue; - } else { - error = Status::FromErrno(); - } - } - } - return error; + if (llvm::Error error = select_helper.Select().takeError()) + return error; + + ssize_t result = ::write(fd, buf, size); + if (result == -1) + return llvm::errorCodeToError( + std::error_code(errno, std::generic_category())); + + return result; } diff --git a/lldb/source/Host/windows/PipeWindows.cpp b/lldb/source/Host/windows/PipeWindows.cpp index d79dc3c2f82c9..a13929b65e087 100644 --- a/lldb/source/Host/windows/PipeWindows.cpp +++ b/lldb/source/Host/windows/PipeWindows.cpp @@ -151,14 +151,13 @@ Status PipeWindows::OpenAsReader(llvm::StringRef name, return OpenNamedPipe(name, child_process_inherit, true); } -Status -PipeWindows::OpenAsWriterWithTimeout(llvm::StringRef name, - bool child_process_inherit, - const std::chrono::microseconds &timeout) { +llvm::Error PipeWindows::OpenAsWriter(llvm::StringRef name, + bool child_process_inherit, + const Timeout &timeout) { if (CanWrite()) - return Status(); // Note the name is ignored. + return llvm::Error::success(); // Note the name is ignored. - return OpenNamedPipe(name, child_process_inherit, false); + return OpenNamedPipe(name, child_process_inherit, false).takeError(); } Status PipeWindows::OpenNamedPipe(llvm::StringRef name, @@ -270,29 +269,24 @@ PipeWindows::GetReadNativeHandle() { return m_read; } HANDLE PipeWindows::GetWriteNativeHandle() { return m_write; } -Status PipeWindows::ReadWithTimeout(void *buf, size_t size, - const std::chrono::microseconds &duration, - size_t &bytes_read) { +llvm::Expected PipeWindows::Read(void *buf, size_t size, + const Timeout &timeout) { if (!CanRead()) - return Status(ERROR_INVALID_HANDLE, eErrorTypeWin32); + return Status(ERROR_INVALID_HANDLE, eErrorTypeWin32).takeError(); - bytes_read = 0; - DWORD sys_bytes_read = 0; - BOOL result = - ::ReadFile(m_read, buf, size, &sys_bytes_read, &m_read_overlapped); - if (result) { - bytes_read = sys_bytes_read; - return Status(); - } + DWORD bytes_read = 0; + BOOL result = ::ReadFile(m_read, buf, size, &bytes_read, &m_read_overlapped); + if (result) + return bytes_read; DWORD failure_error = ::GetLastError(); if (failure_error != ERROR_IO_PENDING) - return Status(failure_error, eErrorTypeWin32); + return Status(failure_error, eErrorTypeWin32).takeError(); - DWORD timeout = (duration == std::chrono::microseconds::zero()) - ? INFINITE - : duration.count() / 1000; - DWORD wait_result = ::WaitForSingleObject(m_read_overlapped.hEvent, timeout); + DWORD timeout_msec = + timeout ? ceil(*timeout).count() : INFINITE; + DWORD wait_result = + ::WaitForSingleObject(m_read_overlapped.hEvent, timeout_msec); if (wait_result != WAIT_OBJECT_0) { // The operation probably failed. However, if it timed out, we need to // cancel the I/O. Between the time we returned from WaitForSingleObject @@ -308,42 +302,36 @@ Status PipeWindows::ReadWithTimeout(void *buf, size_t size, failed = false; } if (failed) - return Status(failure_error, eErrorTypeWin32); + return Status(failure_error, eErrorTypeWin32).takeError(); } // Now we call GetOverlappedResult setting bWait to false, since we've // already waited as long as we're willing to. - if (!::GetOverlappedResult(m_read, &m_read_overlapped, &sys_bytes_read, - FALSE)) - return Status(::GetLastError(), eErrorTypeWin32); + if (!::GetOverlappedResult(m_read, &m_read_overlapped, &bytes_read, FALSE)) + return Status(::GetLastError(), eErrorTypeWin32).takeError(); - bytes_read = sys_bytes_read; - return Status(); + return bytes_read; } -Status PipeWindows::WriteWithTimeout(const void *buf, size_t size, - const std::chrono::microseconds &duration, - size_t &bytes_written) { +llvm::Expected PipeWindows::Write(const void *buf, size_t size, + const Timeout &timeout) { if (!CanWrite()) - return Status(ERROR_INVALID_HANDLE, eErrorTypeWin32); + return Status(ERROR_INVALID_HANDLE, eErrorTypeWin32).takeError(); - bytes_written = 0; - DWORD sys_bytes_write = 0; + DWORD bytes_written = 0; BOOL result = - ::WriteFile(m_write, buf, size, &sys_bytes_write, &m_write_overlapped); - if (result) { - bytes_written = sys_bytes_write; - return Status(); - } + ::WriteFile(m_write, buf, size, &bytes_written, &m_write_overlapped); + if (result) + return bytes_written; DWORD failure_error = ::GetLastError(); if (failure_error != ERROR_IO_PENDING) - return Status(failure_error, eErrorTypeWin32); + return Status(failure_error, eErrorTypeWin32).takeError(); - DWORD timeout = (duration == std::chrono::microseconds::zero()) - ? INFINITE - : duration.count() / 1000; - DWORD wait_result = ::WaitForSingleObject(m_write_overlapped.hEvent, timeout); + DWORD timeout_msec = + timeout ? ceil(*timeout).count() : INFINITE; + DWORD wait_result = + ::WaitForSingleObject(m_write_overlapped.hEvent, timeout_msec); if (wait_result != WAIT_OBJECT_0) { // The operation probably failed. However, if it timed out, we need to // cancel the I/O. Between the time we returned from WaitForSingleObject @@ -359,15 +347,14 @@ Status PipeWindows::WriteWithTimeout(const void *buf, size_t size, failed = false; } if (failed) - return Status(failure_error, eErrorTypeWin32); + return Status(failure_error, eErrorTypeWin32).takeError(); } // Now we call GetOverlappedResult setting bWait to false, since we've // already waited as long as we're willing to. - if (!::GetOverlappedResult(m_write, &m_write_overlapped, &sys_bytes_write, + if (!::GetOverlappedResult(m_write, &m_write_overlapped, &bytes_written, FALSE)) - return Status(::GetLastError(), eErrorTypeWin32); + return Status(::GetLastError(), eErrorTypeWin32).takeError(); - bytes_written = sys_bytes_write; - return Status(); + return bytes_written; } diff --git a/lldb/source/Interpreter/CommandInterpreter.cpp b/lldb/source/Interpreter/CommandInterpreter.cpp index 68831d2831749..231b9c08d7150 100644 --- a/lldb/source/Interpreter/CommandInterpreter.cpp +++ b/lldb/source/Interpreter/CommandInterpreter.cpp @@ -31,6 +31,7 @@ #include "Commands/CommandObjectPlatform.h" #include "Commands/CommandObjectPlugin.h" #include "Commands/CommandObjectProcess.h" +#include "Commands/CommandObjectProtocolServer.h" #include "Commands/CommandObjectQuit.h" #include "Commands/CommandObjectRegexCommand.h" #include "Commands/CommandObjectRegister.h" @@ -583,6 +584,7 @@ void CommandInterpreter::LoadCommandDictionary() { REGISTER_COMMAND_OBJECT("platform", CommandObjectPlatform); REGISTER_COMMAND_OBJECT("plugin", CommandObjectPlugin); REGISTER_COMMAND_OBJECT("process", CommandObjectMultiwordProcess); + REGISTER_COMMAND_OBJECT("protocol-server", CommandObjectProtocolServer); REGISTER_COMMAND_OBJECT("quit", CommandObjectQuit); REGISTER_COMMAND_OBJECT("register", CommandObjectRegister); REGISTER_COMMAND_OBJECT("scripting", CommandObjectMultiwordScripting); diff --git a/lldb/source/Plugins/CMakeLists.txt b/lldb/source/Plugins/CMakeLists.txt index 854f589f45ae0..08f444e7b15e8 100644 --- a/lldb/source/Plugins/CMakeLists.txt +++ b/lldb/source/Plugins/CMakeLists.txt @@ -27,6 +27,10 @@ add_subdirectory(TraceExporter) add_subdirectory(TypeSystem) add_subdirectory(UnwindAssembly) +if(LLDB_ENABLE_PROTOCOL_SERVERS) + add_subdirectory(Protocol) +endif() + set(LLDB_STRIPPED_PLUGINS) get_property(LLDB_ALL_PLUGINS GLOBAL PROPERTY LLDB_PLUGINS) diff --git a/lldb/source/Plugins/Process/gdb-remote/GDBRemoteCommunication.cpp b/lldb/source/Plugins/Process/gdb-remote/GDBRemoteCommunication.cpp index d39ae79fd84f9..a7a04bb521697 100644 --- a/lldb/source/Plugins/Process/gdb-remote/GDBRemoteCommunication.cpp +++ b/lldb/source/Plugins/Process/gdb-remote/GDBRemoteCommunication.cpp @@ -30,7 +30,11 @@ #include "lldb/Utility/Log.h" #include "lldb/Utility/RegularExpression.h" #include "lldb/Utility/StreamString.h" +#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallString.h" +#include "llvm/Config/llvm-config.h" // for LLVM_ENABLE_ZLIB +#include "llvm/ADT/StringRef.h" +#include "llvm/Config/llvm-config.h" // for LLVM_ENABLE_ZLIB #include "llvm/Support/ScopedPrinter.h" #include "ProcessGDBRemoteLog.h" @@ -1147,15 +1151,25 @@ Status GDBRemoteCommunication::StartDebugserverProcess( if (socket_pipe.CanRead()) { char port_cstr[PATH_MAX] = {0}; port_cstr[0] = '\0'; - size_t num_bytes = sizeof(port_cstr); // Read port from pipe with 10 second timeout. - error = socket_pipe.ReadWithTimeout( - port_cstr, num_bytes, std::chrono::seconds{10}, num_bytes); + std::string port_str; + while (error.Success()) { + char buf[10]; + if (llvm::Expected num_bytes = socket_pipe.Read( + buf, std::size(buf), std::chrono::seconds(10))) { + if (*num_bytes == 0) + break; + port_str.append(buf, *num_bytes); + } else { + error = Status::FromError(num_bytes.takeError()); + } + } if (error.Success() && (port != nullptr)) { - assert(num_bytes > 0 && port_cstr[num_bytes - 1] == '\0'); + // NB: Deliberately using .c_str() to stop at embedded '\0's + llvm::StringRef port_ref = port_str.c_str(); uint16_t child_port = 0; // FIXME: improve error handling - llvm::to_integer(port_cstr, child_port); + llvm::to_integer(port_ref, child_port); if (*port == 0 || *port == child_port) { *port = child_port; LLDB_LOGF(log, diff --git a/lldb/source/Plugins/Protocol/CMakeLists.txt b/lldb/source/Plugins/Protocol/CMakeLists.txt new file mode 100644 index 0000000000000..93b347d4cc9d8 --- /dev/null +++ b/lldb/source/Plugins/Protocol/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(MCP) diff --git a/lldb/source/Plugins/Protocol/MCP/CMakeLists.txt b/lldb/source/Plugins/Protocol/MCP/CMakeLists.txt new file mode 100644 index 0000000000000..e104fb527e57a --- /dev/null +++ b/lldb/source/Plugins/Protocol/MCP/CMakeLists.txt @@ -0,0 +1,14 @@ +add_lldb_library(lldbPluginProtocolServerMCP PLUGIN + MCPError.cpp + Protocol.cpp + ProtocolServerMCP.cpp + Resource.cpp + Tool.cpp + + LINK_COMPONENTS + Support + + LINK_LIBS + lldbHost + lldbUtility +) diff --git a/lldb/source/Plugins/Protocol/MCP/MCPError.cpp b/lldb/source/Plugins/Protocol/MCP/MCPError.cpp new file mode 100644 index 0000000000000..659b53a14fe23 --- /dev/null +++ b/lldb/source/Plugins/Protocol/MCP/MCPError.cpp @@ -0,0 +1,45 @@ +//===-- MCPError.cpp ------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "MCPError.h" +#include "llvm/Support/Error.h" +#include "llvm/Support/raw_ostream.h" +#include + +namespace lldb_private::mcp { + +char MCPError::ID; +char UnsupportedURI::ID; + +MCPError::MCPError(std::string message, int64_t error_code) + : m_message(message), m_error_code(error_code) {} + +void MCPError::log(llvm::raw_ostream &OS) const { OS << m_message; } + +std::error_code MCPError::convertToErrorCode() const { + return llvm::inconvertibleErrorCode(); +} + +protocol::Error MCPError::toProtcolError() const { + protocol::Error error; + error.error.code = m_error_code; + error.error.message = m_message; + return error; +} + +UnsupportedURI::UnsupportedURI(std::string uri) : m_uri(uri) {} + +void UnsupportedURI::log(llvm::raw_ostream &OS) const { + OS << "unsupported uri: " << m_uri; +} + +std::error_code UnsupportedURI::convertToErrorCode() const { + return llvm::inconvertibleErrorCode(); +} + +} // namespace lldb_private::mcp diff --git a/lldb/source/Plugins/Protocol/MCP/MCPError.h b/lldb/source/Plugins/Protocol/MCP/MCPError.h new file mode 100644 index 0000000000000..f4db13d6deade --- /dev/null +++ b/lldb/source/Plugins/Protocol/MCP/MCPError.h @@ -0,0 +1,50 @@ +//===-- MCPError.h --------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "Protocol.h" +#include "llvm/Support/Error.h" +#include "llvm/Support/FormatVariadic.h" +#include + +namespace lldb_private::mcp { + +class MCPError : public llvm::ErrorInfo { +public: + static char ID; + + MCPError(std::string message, int64_t error_code = kInternalError); + + void log(llvm::raw_ostream &OS) const override; + std::error_code convertToErrorCode() const override; + + const std::string &getMessage() const { return m_message; } + + protocol::Error toProtcolError() const; + + static constexpr int64_t kResourceNotFound = -32002; + static constexpr int64_t kInternalError = -32603; + +private: + std::string m_message; + int64_t m_error_code; +}; + +class UnsupportedURI : public llvm::ErrorInfo { +public: + static char ID; + + UnsupportedURI(std::string uri); + + void log(llvm::raw_ostream &OS) const override; + std::error_code convertToErrorCode() const override; + +private: + std::string m_uri; +}; + +} // namespace lldb_private::mcp diff --git a/lldb/source/Plugins/Protocol/MCP/Protocol.cpp b/lldb/source/Plugins/Protocol/MCP/Protocol.cpp new file mode 100644 index 0000000000000..274ba6fac01ec --- /dev/null +++ b/lldb/source/Plugins/Protocol/MCP/Protocol.cpp @@ -0,0 +1,266 @@ +//===- Protocol.cpp -------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "Protocol.h" +#include "llvm/Support/JSON.h" + +using namespace llvm; + +namespace lldb_private::mcp::protocol { + +static bool mapRaw(const json::Value &Params, StringLiteral Prop, + std::optional &V, json::Path P) { + const auto *O = Params.getAsObject(); + if (!O) { + P.report("expected object"); + return false; + } + const json::Value *E = O->get(Prop); + if (E) + V = std::move(*E); + return true; +} + +llvm::json::Value toJSON(const Request &R) { + json::Object Result{{"jsonrpc", "2.0"}, {"id", R.id}, {"method", R.method}}; + if (R.params) + Result.insert({"params", R.params}); + return Result; +} + +bool fromJSON(const llvm::json::Value &V, Request &R, llvm::json::Path P) { + llvm::json::ObjectMapper O(V, P); + if (!O || !O.map("id", R.id) || !O.map("method", R.method)) + return false; + return mapRaw(V, "params", R.params, P); +} + +llvm::json::Value toJSON(const ErrorInfo &EI) { + llvm::json::Object Result{{"code", EI.code}, {"message", EI.message}}; + if (!EI.data.empty()) + Result.insert({"data", EI.data}); + return Result; +} + +bool fromJSON(const llvm::json::Value &V, ErrorInfo &EI, llvm::json::Path P) { + llvm::json::ObjectMapper O(V, P); + return O && O.map("code", EI.code) && O.map("message", EI.message) && + O.mapOptional("data", EI.data); +} + +llvm::json::Value toJSON(const Error &E) { + return json::Object{{"jsonrpc", "2.0"}, {"id", E.id}, {"error", E.error}}; +} + +bool fromJSON(const llvm::json::Value &V, Error &E, llvm::json::Path P) { + llvm::json::ObjectMapper O(V, P); + return O && O.map("id", E.id) && O.map("error", E.error); +} + +llvm::json::Value toJSON(const Response &R) { + llvm::json::Object Result{{"jsonrpc", "2.0"}, {"id", R.id}}; + if (R.result) + Result.insert({"result", R.result}); + if (R.error) + Result.insert({"error", R.error}); + return Result; +} + +bool fromJSON(const llvm::json::Value &V, Response &R, llvm::json::Path P) { + llvm::json::ObjectMapper O(V, P); + if (!O || !O.map("id", R.id) || !O.map("error", R.error)) + return false; + return mapRaw(V, "result", R.result, P); +} + +llvm::json::Value toJSON(const Notification &N) { + llvm::json::Object Result{{"jsonrpc", "2.0"}, {"method", N.method}}; + if (N.params) + Result.insert({"params", N.params}); + return Result; +} + +bool fromJSON(const llvm::json::Value &V, Notification &N, llvm::json::Path P) { + llvm::json::ObjectMapper O(V, P); + if (!O || !O.map("method", N.method)) + return false; + auto *Obj = V.getAsObject(); + if (!Obj) + return false; + if (auto *Params = Obj->get("params")) + N.params = *Params; + return true; +} + +llvm::json::Value toJSON(const ToolCapability &TC) { + return llvm::json::Object{{"listChanged", TC.listChanged}}; +} + +bool fromJSON(const llvm::json::Value &V, ToolCapability &TC, + llvm::json::Path P) { + llvm::json::ObjectMapper O(V, P); + return O && O.map("listChanged", TC.listChanged); +} + +llvm::json::Value toJSON(const ResourceCapability &RC) { + return llvm::json::Object{{"listChanged", RC.listChanged}, + {"subscribe", RC.subscribe}}; +} + +bool fromJSON(const llvm::json::Value &V, ResourceCapability &RC, + llvm::json::Path P) { + llvm::json::ObjectMapper O(V, P); + return O && O.map("listChanged", RC.listChanged) && + O.map("subscribe", RC.subscribe); +} + +llvm::json::Value toJSON(const Capabilities &C) { + return llvm::json::Object{{"tools", C.tools}, {"resources", C.resources}}; +} + +bool fromJSON(const llvm::json::Value &V, Resource &R, llvm::json::Path P) { + llvm::json::ObjectMapper O(V, P); + return O && O.map("uri", R.uri) && O.map("name", R.name) && + O.mapOptional("description", R.description) && + O.mapOptional("mimeType", R.mimeType); +} + +llvm::json::Value toJSON(const Resource &R) { + llvm::json::Object Result{{"uri", R.uri}, {"name", R.name}}; + if (!R.description.empty()) + Result.insert({"description", R.description}); + if (!R.mimeType.empty()) + Result.insert({"mimeType", R.mimeType}); + return Result; +} + +bool fromJSON(const llvm::json::Value &V, Capabilities &C, llvm::json::Path P) { + llvm::json::ObjectMapper O(V, P); + return O && O.map("tools", C.tools); +} + +llvm::json::Value toJSON(const ResourceContents &RC) { + llvm::json::Object Result{{"uri", RC.uri}, {"text", RC.text}}; + if (!RC.mimeType.empty()) + Result.insert({"mimeType", RC.mimeType}); + return Result; +} + +bool fromJSON(const llvm::json::Value &V, ResourceContents &RC, + llvm::json::Path P) { + llvm::json::ObjectMapper O(V, P); + return O && O.map("uri", RC.uri) && O.map("text", RC.text) && + O.mapOptional("mimeType", RC.mimeType); +} + +llvm::json::Value toJSON(const ResourceResult &RR) { + return llvm::json::Object{{"contents", RR.contents}}; +} + +bool fromJSON(const llvm::json::Value &V, ResourceResult &RR, + llvm::json::Path P) { + llvm::json::ObjectMapper O(V, P); + return O && O.map("contents", RR.contents); +} + +llvm::json::Value toJSON(const TextContent &TC) { + return llvm::json::Object{{"type", "text"}, {"text", TC.text}}; +} + +bool fromJSON(const llvm::json::Value &V, TextContent &TC, llvm::json::Path P) { + llvm::json::ObjectMapper O(V, P); + return O && O.map("text", TC.text); +} + +llvm::json::Value toJSON(const TextResult &TR) { + return llvm::json::Object{{"content", TR.content}, {"isError", TR.isError}}; +} + +bool fromJSON(const llvm::json::Value &V, TextResult &TR, llvm::json::Path P) { + llvm::json::ObjectMapper O(V, P); + return O && O.map("content", TR.content) && O.map("isError", TR.isError); +} + +llvm::json::Value toJSON(const ToolDefinition &TD) { + llvm::json::Object Result{{"name", TD.name}}; + if (!TD.description.empty()) + Result.insert({"description", TD.description}); + if (TD.inputSchema) + Result.insert({"inputSchema", TD.inputSchema}); + return Result; +} + +bool fromJSON(const llvm::json::Value &V, ToolDefinition &TD, + llvm::json::Path P) { + + llvm::json::ObjectMapper O(V, P); + if (!O || !O.map("name", TD.name) || + !O.mapOptional("description", TD.description)) + return false; + return mapRaw(V, "inputSchema", TD.inputSchema, P); +} + +llvm::json::Value toJSON(const Message &M) { + return std::visit([](auto &M) { return toJSON(M); }, M); +} + +bool fromJSON(const llvm::json::Value &V, Message &M, llvm::json::Path P) { + const auto *O = V.getAsObject(); + if (!O) { + P.report("expected object"); + return false; + } + + if (const json::Value *V = O->get("jsonrpc")) { + if (V->getAsString().value_or("") != "2.0") { + P.report("unsupported JSON RPC version"); + return false; + } + } else { + P.report("not a valid JSON RPC message"); + return false; + } + + // A message without an ID is a Notification. + if (!O->get("id")) { + protocol::Notification N; + if (!fromJSON(V, N, P)) + return false; + M = std::move(N); + return true; + } + + if (O->get("error")) { + protocol::Error E; + if (!fromJSON(V, E, P)) + return false; + M = std::move(E); + return true; + } + + if (O->get("result")) { + protocol::Response R; + if (!fromJSON(V, R, P)) + return false; + M = std::move(R); + return true; + } + + if (O->get("method")) { + protocol::Request R; + if (!fromJSON(V, R, P)) + return false; + M = std::move(R); + return true; + } + + P.report("unrecognized message type"); + return false; +} + +} // namespace lldb_private::mcp::protocol diff --git a/lldb/source/Plugins/Protocol/MCP/Protocol.h b/lldb/source/Plugins/Protocol/MCP/Protocol.h new file mode 100644 index 0000000000000..ce74836e62541 --- /dev/null +++ b/lldb/source/Plugins/Protocol/MCP/Protocol.h @@ -0,0 +1,188 @@ +//===- Protocol.h ---------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains POD structs based on the MCP specification at +// https://github.com/modelcontextprotocol/modelcontextprotocol/blob/main/schema/2024-11-05/schema.json +// +//===----------------------------------------------------------------------===// + +#ifndef LLDB_PLUGINS_PROTOCOL_MCP_PROTOCOL_H +#define LLDB_PLUGINS_PROTOCOL_MCP_PROTOCOL_H + +#include "llvm/Support/JSON.h" +#include +#include +#include + +namespace lldb_private::mcp::protocol { + +static llvm::StringLiteral kVersion = "2024-11-05"; + +/// A request that expects a response. +struct Request { + uint64_t id = 0; + std::string method; + std::optional params; +}; + +llvm::json::Value toJSON(const Request &); +bool fromJSON(const llvm::json::Value &, Request &, llvm::json::Path); + +struct ErrorInfo { + int64_t code = 0; + std::string message; + std::string data; +}; + +llvm::json::Value toJSON(const ErrorInfo &); +bool fromJSON(const llvm::json::Value &, ErrorInfo &, llvm::json::Path); + +struct Error { + uint64_t id = 0; + ErrorInfo error; +}; + +llvm::json::Value toJSON(const Error &); +bool fromJSON(const llvm::json::Value &, Error &, llvm::json::Path); + +struct Response { + uint64_t id = 0; + std::optional result; + std::optional error; +}; + +llvm::json::Value toJSON(const Response &); +bool fromJSON(const llvm::json::Value &, Response &, llvm::json::Path); + +/// A notification which does not expect a response. +struct Notification { + std::string method; + std::optional params; +}; + +llvm::json::Value toJSON(const Notification &); +bool fromJSON(const llvm::json::Value &, Notification &, llvm::json::Path); + +struct ToolCapability { + /// Whether this server supports notifications for changes to the tool list. + bool listChanged = false; +}; + +llvm::json::Value toJSON(const ToolCapability &); +bool fromJSON(const llvm::json::Value &, ToolCapability &, llvm::json::Path); + +struct ResourceCapability { + /// Whether this server supports notifications for changes to the resources + /// list. + bool listChanged = false; + + /// Whether subscriptions are supported. + bool subscribe = false; +}; + +llvm::json::Value toJSON(const ResourceCapability &); +bool fromJSON(const llvm::json::Value &, ResourceCapability &, + llvm::json::Path); + +/// Capabilities that a server may support. Known capabilities are defined here, +/// in this schema, but this is not a closed set: any server can define its own, +/// additional capabilities. +struct Capabilities { + /// Tool capabilities of the server. + ToolCapability tools; + + /// Resource capabilities of the server. + ResourceCapability resources; +}; + +llvm::json::Value toJSON(const Capabilities &); +bool fromJSON(const llvm::json::Value &, Capabilities &, llvm::json::Path); + +/// A known resource that the server is capable of reading. +struct Resource { + /// The URI of this resource. + std::string uri; + + /// A human-readable name for this resource. + std::string name; + + /// A description of what this resource represents. + std::string description; + + /// The MIME type of this resource, if known. + std::string mimeType; +}; + +llvm::json::Value toJSON(const Resource &); +bool fromJSON(const llvm::json::Value &, Resource &, llvm::json::Path); + +/// The contents of a specific resource or sub-resource. +struct ResourceContents { + /// The URI of this resource. + std::string uri; + + /// The text of the item. This must only be set if the item can actually be + /// represented as text (not binary data). + std::string text; + + /// The MIME type of this resource, if known. + std::string mimeType; +}; + +llvm::json::Value toJSON(const ResourceContents &); +bool fromJSON(const llvm::json::Value &, ResourceContents &, llvm::json::Path); + +/// The server's response to a resources/read request from the client. +struct ResourceResult { + std::vector contents; +}; + +llvm::json::Value toJSON(const ResourceResult &); +bool fromJSON(const llvm::json::Value &, ResourceResult &, llvm::json::Path); + +/// Text provided to or from an LLM. +struct TextContent { + /// The text content of the message. + std::string text; +}; + +llvm::json::Value toJSON(const TextContent &); +bool fromJSON(const llvm::json::Value &, TextContent &, llvm::json::Path); + +struct TextResult { + std::vector content; + bool isError = false; +}; + +llvm::json::Value toJSON(const TextResult &); +bool fromJSON(const llvm::json::Value &, TextResult &, llvm::json::Path); + +struct ToolDefinition { + /// Unique identifier for the tool. + std::string name; + + /// Human-readable description. + std::string description; + + // JSON Schema for the tool's parameters. + std::optional inputSchema; +}; + +llvm::json::Value toJSON(const ToolDefinition &); +bool fromJSON(const llvm::json::Value &, ToolDefinition &, llvm::json::Path); + +using Message = std::variant; + +bool fromJSON(const llvm::json::Value &, Message &, llvm::json::Path); +llvm::json::Value toJSON(const Message &); + +using ToolArguments = std::variant; + +} // namespace lldb_private::mcp::protocol + +#endif diff --git a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp new file mode 100644 index 0000000000000..0d79dcdad2d65 --- /dev/null +++ b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp @@ -0,0 +1,412 @@ +//===- ProtocolServerMCP.cpp ----------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "ProtocolServerMCP.h" +#include "MCPError.h" +#include "lldb/Core/PluginManager.h" +#include "lldb/Utility/LLDBLog.h" +#include "lldb/Utility/Log.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/Support/Threading.h" +#include +#include + +using namespace lldb_private; +using namespace lldb_private::mcp; +using namespace llvm; + +LLDB_PLUGIN_DEFINE(ProtocolServerMCP) + +static constexpr size_t kChunkSize = 1024; + +ProtocolServerMCP::ProtocolServerMCP() : ProtocolServer() { + AddRequestHandler("initialize", + std::bind(&ProtocolServerMCP::InitializeHandler, this, + std::placeholders::_1)); + + AddRequestHandler("tools/list", + std::bind(&ProtocolServerMCP::ToolsListHandler, this, + std::placeholders::_1)); + AddRequestHandler("tools/call", + std::bind(&ProtocolServerMCP::ToolsCallHandler, this, + std::placeholders::_1)); + + AddRequestHandler("resources/list", + std::bind(&ProtocolServerMCP::ResourcesListHandler, this, + std::placeholders::_1)); + AddRequestHandler("resources/read", + std::bind(&ProtocolServerMCP::ResourcesReadHandler, this, + std::placeholders::_1)); + AddNotificationHandler( + "notifications/initialized", [](const protocol::Notification &) { + LLDB_LOG(GetLog(LLDBLog::Host), "MCP initialization complete"); + }); + + AddTool( + std::make_unique("lldb_command", "Run an lldb command.")); + + AddResourceProvider(std::make_unique()); +} + +ProtocolServerMCP::~ProtocolServerMCP() { llvm::consumeError(Stop()); } + +void ProtocolServerMCP::Initialize() { + PluginManager::RegisterPlugin(GetPluginNameStatic(), + GetPluginDescriptionStatic(), CreateInstance); +} + +void ProtocolServerMCP::Terminate() { + PluginManager::UnregisterPlugin(CreateInstance); +} + +lldb::ProtocolServerUP ProtocolServerMCP::CreateInstance() { + return std::make_unique(); +} + +llvm::StringRef ProtocolServerMCP::GetPluginDescriptionStatic() { + return "MCP Server."; +} + +llvm::Expected +ProtocolServerMCP::Handle(protocol::Request request) { + auto it = m_request_handlers.find(request.method); + if (it != m_request_handlers.end()) { + llvm::Expected response = it->second(request); + if (!response) + return response; + response->id = request.id; + return *response; + } + + return make_error( + llvm::formatv("no handler for request: {0}", request.method).str()); +} + +void ProtocolServerMCP::Handle(protocol::Notification notification) { + auto it = m_notification_handlers.find(notification.method); + if (it != m_notification_handlers.end()) { + it->second(notification); + return; + } + + LLDB_LOG(GetLog(LLDBLog::Host), "MPC notification: {0} ({1})", + notification.method, notification.params); +} + +void ProtocolServerMCP::AcceptCallback(std::unique_ptr socket) { + LLDB_LOG(GetLog(LLDBLog::Host), "New MCP client ({0}) connected", + m_clients.size() + 1); + + lldb::IOObjectSP io_sp = std::move(socket); + auto client_up = std::make_unique(); + client_up->io_sp = io_sp; + Client *client = client_up.get(); + + Status status; + auto read_handle_up = m_loop.RegisterReadObject( + io_sp, + [this, client](MainLoopBase &loop) { + if (Error error = ReadCallback(*client)) { + LLDB_LOG_ERROR(GetLog(LLDBLog::Host), std::move(error), "{0}"); + client->read_handle_up.reset(); + } + }, + status); + if (status.Fail()) + return; + + client_up->read_handle_up = std::move(read_handle_up); + m_clients.emplace_back(std::move(client_up)); +} + +llvm::Error ProtocolServerMCP::ReadCallback(Client &client) { + char chunk[kChunkSize]; + size_t bytes_read = sizeof(chunk); + if (Status status = client.io_sp->Read(chunk, bytes_read); status.Fail()) + return status.takeError(); + client.buffer.append(chunk, bytes_read); + + for (std::string::size_type pos; + (pos = client.buffer.find('\n')) != std::string::npos;) { + llvm::Expected> message = + HandleData(StringRef(client.buffer.data(), pos)); + client.buffer = client.buffer.erase(0, pos + 1); + if (!message) + return message.takeError(); + + if (*message) { + std::string Output; + llvm::raw_string_ostream OS(Output); + OS << llvm::formatv("{0}", toJSON(**message)) << '\n'; + size_t num_bytes = Output.size(); + return client.io_sp->Write(Output.data(), num_bytes).takeError(); + } + } + + return llvm::Error::success(); +} + +llvm::Error ProtocolServerMCP::Start(ProtocolServer::Connection connection) { + std::lock_guard guard(m_server_mutex); + + if (m_running) + return llvm::createStringError("the MCP server is already running"); + + Status status; + m_listener = Socket::Create(connection.protocol, false, status); + if (status.Fail()) + return status.takeError(); + + status = m_listener->Listen(connection.name, /*backlog=*/5); + if (status.Fail()) + return status.takeError(); + + std::string address = + llvm::join(m_listener->GetListeningConnectionURI(), ", "); + auto handles = + m_listener->Accept(m_loop, std::bind(&ProtocolServerMCP::AcceptCallback, + this, std::placeholders::_1)); + if (llvm::Error error = handles.takeError()) + return error; + + m_running = true; + m_listen_handlers = std::move(*handles); + m_loop_thread = std::thread([=] { + llvm::set_thread_name("protocol-server.mcp"); + m_loop.Run(); + }); + + return llvm::Error::success(); +} + +llvm::Error ProtocolServerMCP::Stop() { + { + std::lock_guard guard(m_server_mutex); + if (!m_running) + return createStringError("the MCP sever is not running"); + m_running = false; + } + + // Stop the main loop. + m_loop.AddPendingCallback( + [](MainLoopBase &loop) { loop.RequestTermination(); }); + + // Wait for the main loop to exit. + if (m_loop_thread.joinable()) + m_loop_thread.join(); + + { + std::lock_guard guard(m_server_mutex); + m_listener.reset(); + m_listen_handlers.clear(); + m_clients.clear(); + } + + return llvm::Error::success(); +} + +llvm::Expected> +ProtocolServerMCP::HandleData(llvm::StringRef data) { + auto message = llvm::json::parse(/*JSON=*/data); + if (!message) + return message.takeError(); + + if (const protocol::Request *request = + std::get_if(&(*message))) { + llvm::Expected response = Handle(*request); + + // Handle failures by converting them into an Error message. + if (!response) { + protocol::Error protocol_error; + llvm::handleAllErrors( + response.takeError(), + [&](const MCPError &err) { protocol_error = err.toProtcolError(); }, + [&](const llvm::ErrorInfoBase &err) { + protocol_error.error.code = MCPError::kInternalError; + protocol_error.error.message = err.message(); + }); + protocol_error.id = request->id; + return protocol_error; + } + + return *response; + } + + if (const protocol::Notification *notification = + std::get_if(&(*message))) { + Handle(*notification); + return std::nullopt; + } + + if (std::get_if(&(*message))) + return llvm::createStringError("unexpected MCP message: error"); + + if (std::get_if(&(*message))) + return llvm::createStringError("unexpected MCP message: response"); + + llvm_unreachable("all message types handled"); +} + +protocol::Capabilities ProtocolServerMCP::GetCapabilities() { + protocol::Capabilities capabilities; + capabilities.tools.listChanged = true; + // FIXME: Support sending notifications when a debugger/target are + // added/removed. + capabilities.resources.listChanged = false; + return capabilities; +} + +void ProtocolServerMCP::AddTool(std::unique_ptr tool) { + std::lock_guard guard(m_server_mutex); + + if (!tool) + return; + m_tools[tool->GetName()] = std::move(tool); +} + +void ProtocolServerMCP::AddResourceProvider( + std::unique_ptr resource_provider) { + std::lock_guard guard(m_server_mutex); + + if (!resource_provider) + return; + m_resource_providers.push_back(std::move(resource_provider)); +} + +void ProtocolServerMCP::AddRequestHandler(llvm::StringRef method, + RequestHandler handler) { + std::lock_guard guard(m_server_mutex); + m_request_handlers[method] = std::move(handler); +} + +void ProtocolServerMCP::AddNotificationHandler(llvm::StringRef method, + NotificationHandler handler) { + std::lock_guard guard(m_server_mutex); + m_notification_handlers[method] = std::move(handler); +} + +llvm::Expected +ProtocolServerMCP::InitializeHandler(const protocol::Request &request) { + protocol::Response response; + response.result.emplace(llvm::json::Object{ + {"protocolVersion", protocol::kVersion}, + {"capabilities", GetCapabilities()}, + {"serverInfo", + llvm::json::Object{{"name", kName}, {"version", kVersion}}}}); + return response; +} + +llvm::Expected +ProtocolServerMCP::ToolsListHandler(const protocol::Request &request) { + protocol::Response response; + + llvm::json::Array tools; + for (const auto &tool : m_tools) + tools.emplace_back(toJSON(tool.second->GetDefinition())); + + response.result.emplace(llvm::json::Object{{"tools", std::move(tools)}}); + + return response; +} + +llvm::Expected +ProtocolServerMCP::ToolsCallHandler(const protocol::Request &request) { + protocol::Response response; + + if (!request.params) + return llvm::createStringError("no tool parameters"); + + const json::Object *param_obj = request.params->getAsObject(); + if (!param_obj) + return llvm::createStringError("no tool parameters"); + + const json::Value *name = param_obj->get("name"); + if (!name) + return llvm::createStringError("no tool name"); + + llvm::StringRef tool_name = name->getAsString().value_or(""); + if (tool_name.empty()) + return llvm::createStringError("no tool name"); + + auto it = m_tools.find(tool_name); + if (it == m_tools.end()) + return llvm::createStringError(llvm::formatv("no tool \"{0}\"", tool_name)); + + protocol::ToolArguments tool_args; + if (const json::Value *args = param_obj->get("arguments")) + tool_args = *args; + + llvm::Expected text_result = + it->second->Call(tool_args); + if (!text_result) + return text_result.takeError(); + + response.result.emplace(toJSON(*text_result)); + + return response; +} + +llvm::Expected +ProtocolServerMCP::ResourcesListHandler(const protocol::Request &request) { + protocol::Response response; + + llvm::json::Array resources; + + std::lock_guard guard(m_server_mutex); + for (std::unique_ptr &resource_provider_up : + m_resource_providers) { + for (const protocol::Resource &resource : + resource_provider_up->GetResources()) + resources.push_back(resource); + } + response.result.emplace( + llvm::json::Object{{"resources", std::move(resources)}}); + + return response; +} + +llvm::Expected +ProtocolServerMCP::ResourcesReadHandler(const protocol::Request &request) { + protocol::Response response; + + if (!request.params) + return llvm::createStringError("no resource parameters"); + + const json::Object *param_obj = request.params->getAsObject(); + if (!param_obj) + return llvm::createStringError("no resource parameters"); + + const json::Value *uri = param_obj->get("uri"); + if (!uri) + return llvm::createStringError("no resource uri"); + + llvm::StringRef uri_str = uri->getAsString().value_or(""); + if (uri_str.empty()) + return llvm::createStringError("no resource uri"); + + std::lock_guard guard(m_server_mutex); + for (std::unique_ptr &resource_provider_up : + m_resource_providers) { + llvm::Expected result = + resource_provider_up->ReadResource(uri_str); + if (result.errorIsA()) { + llvm::consumeError(result.takeError()); + continue; + } + if (!result) + return result.takeError(); + + protocol::Response response; + response.result.emplace(std::move(*result)); + return response; + } + + return make_error( + llvm::formatv("no resource handler for uri: {0}", uri_str).str(), + MCPError::kResourceNotFound); +} diff --git a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h new file mode 100644 index 0000000000000..e273f6e2a8d37 --- /dev/null +++ b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h @@ -0,0 +1,108 @@ +//===- ProtocolServerMCP.h ------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLDB_PLUGINS_PROTOCOL_MCP_PROTOCOLSERVERMCP_H +#define LLDB_PLUGINS_PROTOCOL_MCP_PROTOCOLSERVERMCP_H + +#include "Protocol.h" +#include "Resource.h" +#include "Tool.h" +#include "lldb/Core/ProtocolServer.h" +#include "lldb/Host/MainLoop.h" +#include "lldb/Host/Socket.h" +#include "llvm/ADT/StringMap.h" +#include + +namespace lldb_private::mcp { + +class ProtocolServerMCP : public ProtocolServer { +public: + ProtocolServerMCP(); + virtual ~ProtocolServerMCP() override; + + virtual llvm::Error Start(ProtocolServer::Connection connection) override; + virtual llvm::Error Stop() override; + + static void Initialize(); + static void Terminate(); + + static llvm::StringRef GetPluginNameStatic() { return "MCP"; } + static llvm::StringRef GetPluginDescriptionStatic(); + + static lldb::ProtocolServerUP CreateInstance(); + + llvm::StringRef GetPluginName() override { return GetPluginNameStatic(); } + + Socket *GetSocket() const override { return m_listener.get(); } + +protected: + using RequestHandler = std::function( + const protocol::Request &)>; + using NotificationHandler = + std::function; + + void AddTool(std::unique_ptr tool); + void AddResourceProvider(std::unique_ptr resource_provider); + + void AddRequestHandler(llvm::StringRef method, RequestHandler handler); + void AddNotificationHandler(llvm::StringRef method, + NotificationHandler handler); + +private: + void AcceptCallback(std::unique_ptr socket); + + llvm::Expected> + HandleData(llvm::StringRef data); + + llvm::Expected Handle(protocol::Request request); + void Handle(protocol::Notification notification); + + llvm::Expected + InitializeHandler(const protocol::Request &); + + llvm::Expected + ToolsListHandler(const protocol::Request &); + llvm::Expected + ToolsCallHandler(const protocol::Request &); + + llvm::Expected + ResourcesListHandler(const protocol::Request &); + llvm::Expected + ResourcesReadHandler(const protocol::Request &); + + protocol::Capabilities GetCapabilities(); + + llvm::StringLiteral kName = "lldb-mcp"; + llvm::StringLiteral kVersion = "0.1.0"; + + bool m_running = false; + + MainLoop m_loop; + std::thread m_loop_thread; + + std::unique_ptr m_listener; + std::vector m_listen_handlers; + + struct Client { + lldb::IOObjectSP io_sp; + MainLoopBase::ReadHandleUP read_handle_up; + std::string buffer; + }; + llvm::Error ReadCallback(Client &client); + std::vector> m_clients; + + std::mutex m_server_mutex; + llvm::StringMap> m_tools; + std::vector> m_resource_providers; + + llvm::StringMap m_request_handlers; + llvm::StringMap m_notification_handlers; +}; +} // namespace lldb_private::mcp + +#endif diff --git a/lldb/source/Plugins/Protocol/MCP/Resource.cpp b/lldb/source/Plugins/Protocol/MCP/Resource.cpp new file mode 100644 index 0000000000000..d75d5b6dd6a41 --- /dev/null +++ b/lldb/source/Plugins/Protocol/MCP/Resource.cpp @@ -0,0 +1,217 @@ +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "Resource.h" +#include "MCPError.h" +#include "lldb/Core/Debugger.h" +#include "lldb/Core/Module.h" +#include "lldb/Target/Platform.h" + +using namespace lldb_private::mcp; + +namespace { +struct DebuggerResource { + uint64_t debugger_id = 0; + std::string name; + uint64_t num_targets = 0; +}; + +llvm::json::Value toJSON(const DebuggerResource &DR) { + llvm::json::Object Result{{"debugger_id", DR.debugger_id}, + {"num_targets", DR.num_targets}}; + if (!DR.name.empty()) + Result.insert({"name", DR.name}); + return Result; +} + +struct TargetResource { + size_t debugger_id = 0; + size_t target_idx = 0; + bool selected = false; + bool dummy = false; + std::string arch; + std::string path; + std::string platform; +}; + +llvm::json::Value toJSON(const TargetResource &TR) { + llvm::json::Object Result{{"debugger_id", TR.debugger_id}, + {"target_idx", TR.target_idx}, + {"selected", TR.selected}, + {"dummy", TR.dummy}}; + if (!TR.arch.empty()) + Result.insert({"arch", TR.arch}); + if (!TR.path.empty()) + Result.insert({"path", TR.path}); + if (!TR.platform.empty()) + Result.insert({"platform", TR.platform}); + return Result; +} +} // namespace + +static constexpr llvm::StringLiteral kMimeTypeJSON = "application/json"; + +template +static llvm::Error createStringError(const char *format, Args &&...args) { + return llvm::createStringError( + llvm::formatv(format, std::forward(args)...).str()); +} + +static llvm::Error createUnsupportedURIError(llvm::StringRef uri) { + return llvm::make_error(uri.str()); +} + +protocol::Resource +DebuggerResourceProvider::GetDebuggerResource(Debugger &debugger) { + const lldb::user_id_t debugger_id = debugger.GetID(); + + protocol::Resource resource; + resource.uri = llvm::formatv("lldb://debugger/{0}", debugger_id); + resource.name = debugger.GetInstanceName(); + resource.description = + llvm::formatv("Information about debugger instance {0}: {1}", debugger_id, + debugger.GetInstanceName()); + resource.mimeType = kMimeTypeJSON; + return resource; +} + +protocol::Resource +DebuggerResourceProvider::GetTargetResource(size_t target_idx, Target &target) { + const size_t debugger_id = target.GetDebugger().GetID(); + + std::string target_name = llvm::formatv("target {0}", target_idx); + + if (Module *exe_module = target.GetExecutableModulePointer()) + target_name = exe_module->GetFileSpec().GetFilename().GetString(); + + protocol::Resource resource; + resource.uri = + llvm::formatv("lldb://debugger/{0}/target/{1}", debugger_id, target_idx); + resource.name = target_name; + resource.description = + llvm::formatv("Information about target {0} in debugger instance {1}", + target_idx, debugger_id); + resource.mimeType = kMimeTypeJSON; + return resource; +} + +std::vector DebuggerResourceProvider::GetResources() const { + std::vector resources; + + const size_t num_debuggers = Debugger::GetNumDebuggers(); + for (size_t i = 0; i < num_debuggers; ++i) { + lldb::DebuggerSP debugger_sp = Debugger::GetDebuggerAtIndex(i); + if (!debugger_sp) + continue; + resources.emplace_back(GetDebuggerResource(*debugger_sp)); + + TargetList &target_list = debugger_sp->GetTargetList(); + const size_t num_targets = target_list.GetNumTargets(); + for (size_t j = 0; j < num_targets; ++j) { + lldb::TargetSP target_sp = target_list.GetTargetAtIndex(j); + if (!target_sp) + continue; + resources.emplace_back(GetTargetResource(j, *target_sp)); + } + } + + return resources; +} + +llvm::Expected +DebuggerResourceProvider::ReadResource(llvm::StringRef uri) const { + + auto [protocol, path] = uri.split("://"); + + if (protocol != "lldb") + return createUnsupportedURIError(uri); + + llvm::SmallVector components; + path.split(components, '/'); + + if (components.size() < 2) + return createUnsupportedURIError(uri); + + if (components[0] != "debugger") + return createUnsupportedURIError(uri); + + size_t debugger_idx; + if (components[1].getAsInteger(0, debugger_idx)) + return createStringError("invalid debugger id '{0}': {1}", components[1], + path); + + if (components.size() > 3) { + if (components[2] != "target") + return createUnsupportedURIError(uri); + + size_t target_idx; + if (components[3].getAsInteger(0, target_idx)) + return createStringError("invalid target id '{0}': {1}", components[3], + path); + + return ReadTargetResource(uri, debugger_idx, target_idx); + } + + return ReadDebuggerResource(uri, debugger_idx); +} + +llvm::Expected +DebuggerResourceProvider::ReadDebuggerResource(llvm::StringRef uri, + lldb::user_id_t debugger_id) { + lldb::DebuggerSP debugger_sp = Debugger::FindDebuggerWithID(debugger_id); + if (!debugger_sp) + return createStringError("invalid debugger id: {0}", debugger_id); + + DebuggerResource debugger_resource; + debugger_resource.debugger_id = debugger_id; + debugger_resource.name = debugger_sp->GetInstanceName(); + debugger_resource.num_targets = debugger_sp->GetTargetList().GetNumTargets(); + + protocol::ResourceContents contents; + contents.uri = uri; + contents.mimeType = kMimeTypeJSON; + contents.text = llvm::formatv("{0}", toJSON(debugger_resource)); + + protocol::ResourceResult result; + result.contents.push_back(contents); + return result; +} + +llvm::Expected +DebuggerResourceProvider::ReadTargetResource(llvm::StringRef uri, + lldb::user_id_t debugger_id, + size_t target_idx) { + + lldb::DebuggerSP debugger_sp = Debugger::FindDebuggerWithID(debugger_id); + if (!debugger_sp) + return createStringError("invalid debugger id: {0}", debugger_id); + + TargetList &target_list = debugger_sp->GetTargetList(); + lldb::TargetSP target_sp = target_list.GetTargetAtIndex(target_idx); + if (!target_sp) + return createStringError("invalid target idx: {0}", target_idx); + + TargetResource target_resource; + target_resource.debugger_id = debugger_id; + target_resource.target_idx = target_idx; + target_resource.arch = target_sp->GetArchitecture().GetTriple().str(); + target_resource.dummy = target_sp->IsDummyTarget(); + target_resource.selected = target_sp == debugger_sp->GetSelectedTarget(); + + if (Module *exe_module = target_sp->GetExecutableModulePointer()) + target_resource.path = exe_module->GetFileSpec().GetPath(); + if (lldb::PlatformSP platform_sp = target_sp->GetPlatform()) + target_resource.platform = platform_sp->GetName(); + + protocol::ResourceContents contents; + contents.uri = uri; + contents.mimeType = kMimeTypeJSON; + contents.text = llvm::formatv("{0}", toJSON(target_resource)); + + protocol::ResourceResult result; + result.contents.push_back(contents); + return result; +} diff --git a/lldb/source/Plugins/Protocol/MCP/Resource.h b/lldb/source/Plugins/Protocol/MCP/Resource.h new file mode 100644 index 0000000000000..5ac38e7e878ff --- /dev/null +++ b/lldb/source/Plugins/Protocol/MCP/Resource.h @@ -0,0 +1,51 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLDB_PLUGINS_PROTOCOL_MCP_RESOURCE_H +#define LLDB_PLUGINS_PROTOCOL_MCP_RESOURCE_H + +#include "Protocol.h" +#include "lldb/lldb-private.h" +#include + +namespace lldb_private::mcp { + +class ResourceProvider { +public: + ResourceProvider() = default; + virtual ~ResourceProvider() = default; + + virtual std::vector GetResources() const = 0; + virtual llvm::Expected + ReadResource(llvm::StringRef uri) const = 0; +}; + +class DebuggerResourceProvider : public ResourceProvider { +public: + using ResourceProvider::ResourceProvider; + virtual ~DebuggerResourceProvider() = default; + + virtual std::vector GetResources() const override; + virtual llvm::Expected + ReadResource(llvm::StringRef uri) const override; + +private: + static protocol::Resource GetDebuggerResource(Debugger &debugger); + static protocol::Resource GetTargetResource(size_t target_idx, + Target &target); + + static llvm::Expected + ReadDebuggerResource(llvm::StringRef uri, lldb::user_id_t debugger_id); + static llvm::Expected + ReadTargetResource(llvm::StringRef uri, lldb::user_id_t debugger_id, + size_t target_idx); +}; + +} // namespace lldb_private::mcp + +#endif diff --git a/lldb/source/Plugins/Protocol/MCP/Tool.cpp b/lldb/source/Plugins/Protocol/MCP/Tool.cpp new file mode 100644 index 0000000000000..bbc19a1e51942 --- /dev/null +++ b/lldb/source/Plugins/Protocol/MCP/Tool.cpp @@ -0,0 +1,103 @@ +//===- Tool.cpp -----------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "Tool.h" +#include "lldb/Core/Module.h" +#include "lldb/Interpreter/CommandInterpreter.h" +#include "lldb/Interpreter/CommandReturnObject.h" + +using namespace lldb_private::mcp; +using namespace llvm; + +namespace { +struct CommandToolArguments { + uint64_t debugger_id; + std::string arguments; +}; + +bool fromJSON(const llvm::json::Value &V, CommandToolArguments &A, + llvm::json::Path P) { + llvm::json::ObjectMapper O(V, P); + return O && O.map("debugger_id", A.debugger_id) && + O.mapOptional("arguments", A.arguments); +} + +/// Helper function to create a TextResult from a string output. +static lldb_private::mcp::protocol::TextResult +createTextResult(std::string output, bool is_error = false) { + lldb_private::mcp::protocol::TextResult text_result; + text_result.content.emplace_back( + lldb_private::mcp::protocol::TextContent{{std::move(output)}}); + text_result.isError = is_error; + return text_result; +} + +} // namespace + +Tool::Tool(std::string name, std::string description) + : m_name(std::move(name)), m_description(std::move(description)) {} + +protocol::ToolDefinition Tool::GetDefinition() const { + protocol::ToolDefinition definition; + definition.name = m_name; + definition.description = m_description; + + if (std::optional input_schema = GetSchema()) + definition.inputSchema = *input_schema; + + return definition; +} + +llvm::Expected +CommandTool::Call(const protocol::ToolArguments &args) { + if (!std::holds_alternative(args)) + return createStringError("CommandTool requires arguments"); + + json::Path::Root root; + + CommandToolArguments arguments; + if (!fromJSON(std::get(args), arguments, root)) + return root.getError(); + + lldb::DebuggerSP debugger_sp = + Debugger::FindDebuggerWithID(arguments.debugger_id); + if (!debugger_sp) + return createStringError( + llvm::formatv("no debugger with id {0}", arguments.debugger_id)); + + // FIXME: Disallow certain commands and their aliases. + CommandReturnObject result(/*colors=*/false); + debugger_sp->GetCommandInterpreter().HandleCommand( + arguments.arguments.c_str(), eLazyBoolYes, result); + + std::string output; + llvm::StringRef output_str = result.GetOutputString(); + if (!output_str.empty()) + output += output_str.str(); + + std::string err_str = result.GetErrorString(); + if (!err_str.empty()) { + if (!output.empty()) + output += '\n'; + output += err_str; + } + + return createTextResult(output, !result.Succeeded()); +} + +std::optional CommandTool::GetSchema() const { + llvm::json::Object id_type{{"type", "number"}}; + llvm::json::Object str_type{{"type", "string"}}; + llvm::json::Object properties{{"debugger_id", std::move(id_type)}, + {"arguments", std::move(str_type)}}; + llvm::json::Array required{"debugger_id"}; + llvm::json::Object schema{{"type", "object"}, + {"properties", std::move(properties)}, + {"required", std::move(required)}}; + return schema; +} diff --git a/lldb/source/Plugins/Protocol/MCP/Tool.h b/lldb/source/Plugins/Protocol/MCP/Tool.h new file mode 100644 index 0000000000000..d0f639adad24e --- /dev/null +++ b/lldb/source/Plugins/Protocol/MCP/Tool.h @@ -0,0 +1,53 @@ +//===- Tool.h -------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLDB_PLUGINS_PROTOCOL_MCP_TOOL_H +#define LLDB_PLUGINS_PROTOCOL_MCP_TOOL_H + +#include "Protocol.h" +#include "lldb/Core/Debugger.h" +#include "llvm/Support/JSON.h" +#include + +namespace lldb_private::mcp { + +class Tool { +public: + Tool(std::string name, std::string description); + virtual ~Tool() = default; + + virtual llvm::Expected + Call(const protocol::ToolArguments &args) = 0; + + virtual std::optional GetSchema() const { + return llvm::json::Object{{"type", "object"}}; + } + + protocol::ToolDefinition GetDefinition() const; + + const std::string &GetName() { return m_name; } + +private: + std::string m_name; + std::string m_description; +}; + +class CommandTool : public mcp::Tool { +public: + using mcp::Tool::Tool; + ~CommandTool() = default; + + virtual llvm::Expected + Call(const protocol::ToolArguments &args) override; + + virtual std::optional GetSchema() const override; +}; + +} // namespace lldb_private::mcp + +#endif diff --git a/lldb/source/Target/Process.cpp b/lldb/source/Target/Process.cpp index 8be1ecffead0f..19b21e201afc4 100644 --- a/lldb/source/Target/Process.cpp +++ b/lldb/source/Target/Process.cpp @@ -4748,15 +4748,16 @@ class IOHandlerProcessSTDIO : public IOHandler { } if (select_helper.FDIsSetRead(pipe_read_fd)) { - size_t bytes_read; // Consume the interrupt byte - Status error = m_pipe.Read(&ch, 1, bytes_read); - if (error.Success()) { + if (llvm::Expected bytes_read = m_pipe.Read(&ch, 1)) { if (ch == 'q') break; if (ch == 'i') if (StateIsRunningState(m_process->GetState())) m_process->SendAsyncInterrupt(); + } else { + LLDB_LOG_ERROR(GetLog(LLDBLog::Process), bytes_read.takeError(), + "Pipe read failed: {0}"); } } } @@ -4780,8 +4781,10 @@ class IOHandlerProcessSTDIO : public IOHandler { // deadlocking when the pipe gets fed up and blocks until data is consumed. if (m_is_running) { char ch = 'q'; // Send 'q' for quit - size_t bytes_written = 0; - m_pipe.Write(&ch, 1, bytes_written); + if (llvm::Error err = m_pipe.Write(&ch, 1).takeError()) { + LLDB_LOG_ERROR(GetLog(LLDBLog::Process), std::move(err), + "Pipe write failed: {0}"); + } } } @@ -4793,9 +4796,7 @@ class IOHandlerProcessSTDIO : public IOHandler { // m_process->SendAsyncInterrupt() from a much safer location in code. if (m_active) { char ch = 'i'; // Send 'i' for interrupt - size_t bytes_written = 0; - Status result = m_pipe.Write(&ch, 1, bytes_written); - return result.Success(); + return !errorToBool(m_pipe.Write(&ch, 1).takeError()); } else { // This IOHandler might be pushed on the stack, but not being run // currently so do the right thing if we aren't actively watching for diff --git a/lldb/tools/lldb-server/lldb-gdbserver.cpp b/lldb/tools/lldb-server/lldb-gdbserver.cpp index 563284730bc70..1ecbdad3ca5c0 100644 --- a/lldb/tools/lldb-server/lldb-gdbserver.cpp +++ b/lldb/tools/lldb-server/lldb-gdbserver.cpp @@ -167,27 +167,35 @@ void handle_launch(GDBRemoteCommunicationServerLLGS &gdb_server, } } -Status writeSocketIdToPipe(Pipe &port_pipe, llvm::StringRef socket_id) { - size_t bytes_written = 0; - // Write the port number as a C string with the NULL terminator. - return port_pipe.Write(socket_id.data(), socket_id.size() + 1, bytes_written); +static Status writeSocketIdToPipe(Pipe &port_pipe, + const std::string &socket_id) { + // NB: Include the nul character at the end. + llvm::StringRef buf(socket_id.data(), socket_id.size() + 1); + while (!buf.empty()) { + if (llvm::Expected written = + port_pipe.Write(buf.data(), buf.size())) + buf = buf.drop_front(*written); + else + return Status::FromError(written.takeError()); + } + return Status(); } Status writeSocketIdToPipe(const char *const named_pipe_path, llvm::StringRef socket_id) { Pipe port_name_pipe; // Wait for 10 seconds for pipe to be opened. - auto error = port_name_pipe.OpenAsWriterWithTimeout(named_pipe_path, false, - std::chrono::seconds{10}); - if (error.Fail()) - return error; - return writeSocketIdToPipe(port_name_pipe, socket_id); + if (llvm::Error err = port_name_pipe.OpenAsWriter(named_pipe_path, false, + std::chrono::seconds{10})) + return Status::FromError(std::move(err)); + + return writeSocketIdToPipe(port_name_pipe, socket_id.str()); } Status writeSocketIdToPipe(lldb::pipe_t unnamed_pipe, llvm::StringRef socket_id) { Pipe port_pipe{LLDB_INVALID_PIPE, unnamed_pipe}; - return writeSocketIdToPipe(port_pipe, socket_id); + return writeSocketIdToPipe(port_pipe, socket_id.str()); } void ConnectToRemote(MainLoop &mainloop, diff --git a/lldb/unittests/CMakeLists.txt b/lldb/unittests/CMakeLists.txt index 926f8a8602472..95af91ea05883 100644 --- a/lldb/unittests/CMakeLists.txt +++ b/lldb/unittests/CMakeLists.txt @@ -84,6 +84,10 @@ add_subdirectory(Utility) add_subdirectory(Thread) add_subdirectory(ValueObject) +if(LLDB_ENABLE_PROTOCOL_SERVERS) + add_subdirectory(Protocol) +endif() + if(LLDB_CAN_USE_DEBUGSERVER AND LLDB_TOOL_DEBUGSERVER_BUILD AND NOT LLDB_USE_SYSTEM_DEBUGSERVER) add_subdirectory(debugserver) endif() diff --git a/lldb/unittests/Core/SwiftDemanglingPartsTest.cpp b/lldb/unittests/Core/SwiftDemanglingPartsTest.cpp index 4d053504f63ad..1ac35425e039d 100644 --- a/lldb/unittests/Core/SwiftDemanglingPartsTest.cpp +++ b/lldb/unittests/Core/SwiftDemanglingPartsTest.cpp @@ -6,15 +6,17 @@ // //===----------------------------------------------------------------------===// -#include "Plugins/Language/Swift/SwiftMangled.h" -#include "Plugins/LanguageRuntime/Swift/SwiftLanguageRuntime.h" #include "TestingSupport/TestUtilities.h" - #include "lldb/Core/DemangledNameInfo.h" #include "lldb/Core/Mangled.h" - +#include "lldb/Host/Config.h" #include "gtest/gtest.h" +#ifdef LLDB_ENABLE_SWIFT + +#include "Plugins/Language/Swift/SwiftMangled.h" +#include "Plugins/LanguageRuntime/Swift/SwiftLanguageRuntime.h" + using namespace lldb; using namespace lldb_private; @@ -1261,4 +1263,6 @@ TEST_P(SwiftDemanglingPartsTestFixture, SwiftDemanglingParts) { INSTANTIATE_TEST_SUITE_P( SwiftDemanglingPartsTests, SwiftDemanglingPartsTestFixture, - ::testing::ValuesIn(g_swift_demangling_parts_test_cases)); \ No newline at end of file + ::testing::ValuesIn(g_swift_demangling_parts_test_cases)); + +#endif diff --git a/lldb/unittests/DAP/TestBase.cpp b/lldb/unittests/DAP/TestBase.cpp new file mode 100644 index 0000000000000..d5d36158d68e0 --- /dev/null +++ b/lldb/unittests/DAP/TestBase.cpp @@ -0,0 +1,129 @@ +//===-- TestBase.cpp ------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "TestBase.h" +#include "Protocol/ProtocolBase.h" +#include "TestingSupport/TestUtilities.h" +#include "lldb/API/SBDefines.h" +#include "lldb/API/SBStructuredData.h" +#include "lldb/Host/File.h" +#include "lldb/Host/Pipe.h" +#include "lldb/lldb-forward.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Testing/Support/Error.h" +#include "gtest/gtest.h" +#include + +using namespace llvm; +using namespace lldb; +using namespace lldb_dap; +using namespace lldb_dap::protocol; +using namespace lldb_dap_tests; +using lldb_private::File; +using lldb_private::NativeFile; +using lldb_private::Pipe; + +void TransportBase::SetUp() { + PipePairTest::SetUp(); + to_dap = std::make_unique( + "to_dap", nullptr, + std::make_shared(input.GetReadFileDescriptor(), + File::eOpenOptionReadOnly, + NativeFile::Unowned), + std::make_shared(output.GetWriteFileDescriptor(), + File::eOpenOptionWriteOnly, + NativeFile::Unowned)); + from_dap = std::make_unique( + "from_dap", nullptr, + std::make_shared(output.GetReadFileDescriptor(), + File::eOpenOptionReadOnly, + NativeFile::Unowned), + std::make_shared(input.GetWriteFileDescriptor(), + File::eOpenOptionWriteOnly, + NativeFile::Unowned)); +} + +void DAPTestBase::SetUp() { + TransportBase::SetUp(); + dap = std::make_unique( + /*log=*/nullptr, + /*default_repl_mode=*/ReplMode::Auto, + /*pre_init_commands=*/std::vector(), + /*transport=*/*to_dap); +} + +void DAPTestBase::TearDown() { + if (core) + ASSERT_THAT_ERROR(core->discard(), Succeeded()); + if (binary) + ASSERT_THAT_ERROR(binary->discard(), Succeeded()); +} + +void DAPTestBase::SetUpTestSuite() { + lldb::SBError error = SBDebugger::InitializeWithErrorHandling(); + EXPECT_TRUE(error.Success()); +} +void DAPTestBase::TeatUpTestSuite() { SBDebugger::Terminate(); } + +bool DAPTestBase::GetDebuggerSupportsTarget(llvm::StringRef platform) { + EXPECT_TRUE(dap->debugger); + + lldb::SBStructuredData data = dap->debugger.GetBuildConfiguration() + .GetValueForKey("targets") + .GetValueForKey("value"); + for (size_t i = 0; i < data.GetSize(); i++) { + char buf[100] = {0}; + size_t size = data.GetItemAtIndex(i).GetStringValue(buf, sizeof(buf)); + if (llvm::StringRef(buf, size) == platform) + return true; + } + + return false; +} + +void DAPTestBase::CreateDebugger() { + dap->debugger = lldb::SBDebugger::Create(); + ASSERT_TRUE(dap->debugger); +} + +void DAPTestBase::LoadCore() { + ASSERT_TRUE(dap->debugger); + llvm::Expected binary_yaml = + lldb_private::TestFile::fromYamlFile(k_linux_binary); + ASSERT_THAT_EXPECTED(binary_yaml, Succeeded()); + llvm::Expected binary_file = + binary_yaml->writeToTemporaryFile(); + ASSERT_THAT_EXPECTED(binary_file, Succeeded()); + binary = std::move(*binary_file); + dap->target = dap->debugger.CreateTarget(binary->TmpName.data()); + ASSERT_TRUE(dap->target); + llvm::Expected core_yaml = + lldb_private::TestFile::fromYamlFile(k_linux_core); + ASSERT_THAT_EXPECTED(core_yaml, Succeeded()); + llvm::Expected core_file = + core_yaml->writeToTemporaryFile(); + ASSERT_THAT_EXPECTED(core_file, Succeeded()); + this->core = std::move(*core_file); + SBProcess process = dap->target.LoadCore(this->core->TmpName.data()); + ASSERT_TRUE(process); +} + +std::vector DAPTestBase::DrainOutput() { + std::vector msgs; + output.CloseWriteFileDescriptor(); + while (true) { + Expected next = + from_dap->Read(std::chrono::milliseconds(1)); + if (!next) { + consumeError(next.takeError()); + break; + } + msgs.push_back(*next); + } + return msgs; +} diff --git a/lldb/unittests/Host/CMakeLists.txt b/lldb/unittests/Host/CMakeLists.txt index e2cb0a9e5713a..7c7fabf9716e0 100644 --- a/lldb/unittests/Host/CMakeLists.txt +++ b/lldb/unittests/Host/CMakeLists.txt @@ -13,6 +13,7 @@ set (FILES HostInfoTest.cpp HostTest.cpp MainLoopTest.cpp + JSONTransportTest.cpp NativeProcessProtocolTest.cpp PipeTest.cpp ProcessLaunchInfoTest.cpp diff --git a/lldb/unittests/Host/JSONTransportTest.cpp b/lldb/unittests/Host/JSONTransportTest.cpp new file mode 100644 index 0000000000000..d54d121500be0 --- /dev/null +++ b/lldb/unittests/Host/JSONTransportTest.cpp @@ -0,0 +1,176 @@ +//===-- JSONTransportTest.cpp ---------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "lldb/Host/JSONTransport.h" +#include "TestingSupport/Host/PipeTestUtilities.h" +#include "lldb/Host/File.h" + +using namespace llvm; +using namespace lldb_private; + +namespace { +template class JSONTransportTest : public PipePairTest { +protected: + std::unique_ptr transport; + + void SetUp() override { + PipePairTest::SetUp(); + transport = std::make_unique( + std::make_shared(input.GetReadFileDescriptor(), + File::eOpenOptionReadOnly, + NativeFile::Unowned), + std::make_shared(output.GetWriteFileDescriptor(), + File::eOpenOptionWriteOnly, + NativeFile::Unowned)); + } +}; + +class HTTPDelimitedJSONTransportTest + : public JSONTransportTest { +public: + using JSONTransportTest::JSONTransportTest; +}; + +class JSONRPCTransportTest : public JSONTransportTest { +public: + using JSONTransportTest::JSONTransportTest; +}; + +struct JSONTestType { + std::string str; +}; + +llvm::json::Value toJSON(const JSONTestType &T) { + return llvm::json::Object{{"str", T.str}}; +} + +bool fromJSON(const llvm::json::Value &V, JSONTestType &T, llvm::json::Path P) { + llvm::json::ObjectMapper O(V, P); + return O && O.map("str", T.str); +} +} // namespace + +TEST_F(HTTPDelimitedJSONTransportTest, MalformedRequests) { + std::string malformed_header = "COnTent-LenGth: -1{}\r\n\r\nnotjosn"; + ASSERT_THAT_EXPECTED( + input.Write(malformed_header.data(), malformed_header.size()), + Succeeded()); + ASSERT_THAT_EXPECTED( + transport->Read(std::chrono::milliseconds(1)), + FailedWithMessage( + "expected 'Content-Length: ' and got 'COnTent-LenGth: '")); +} + +TEST_F(HTTPDelimitedJSONTransportTest, Read) { + std::string json = R"json({"str": "foo"})json"; + std::string message = + formatv("Content-Length: {0}\r\n\r\n{1}", json.size(), json).str(); + ASSERT_THAT_EXPECTED(input.Write(message.data(), message.size()), + Succeeded()); + ASSERT_THAT_EXPECTED( + transport->Read(std::chrono::milliseconds(1)), + HasValue(testing::FieldsAre(/*str=*/"foo"))); +} + +TEST_F(HTTPDelimitedJSONTransportTest, ReadWithEOF) { + input.CloseWriteFileDescriptor(); + ASSERT_THAT_EXPECTED( + transport->Read(std::chrono::milliseconds(1)), + Failed()); +} + +TEST_F(HTTPDelimitedJSONTransportTest, ReadAfterClosed) { + input.CloseReadFileDescriptor(); + ASSERT_THAT_EXPECTED( + transport->Read(std::chrono::milliseconds(1)), + llvm::Failed()); +} + +TEST_F(HTTPDelimitedJSONTransportTest, InvalidTransport) { + transport = std::make_unique(nullptr, nullptr); + ASSERT_THAT_EXPECTED( + transport->Read(std::chrono::milliseconds(1)), + Failed()); +} + +TEST_F(HTTPDelimitedJSONTransportTest, Write) { + ASSERT_THAT_ERROR(transport->Write(JSONTestType{"foo"}), Succeeded()); + output.CloseWriteFileDescriptor(); + char buf[1024]; + Expected bytes_read = + output.Read(buf, sizeof(buf), std::chrono::milliseconds(1)); + ASSERT_THAT_EXPECTED(bytes_read, Succeeded()); + ASSERT_EQ(StringRef(buf, *bytes_read), StringRef("Content-Length: 13\r\n\r\n" + R"json({"str":"foo"})json")); +} + +TEST_F(JSONRPCTransportTest, MalformedRequests) { + std::string malformed_header = "notjson\n"; + ASSERT_THAT_EXPECTED( + input.Write(malformed_header.data(), malformed_header.size()), + Succeeded()); + ASSERT_THAT_EXPECTED( + transport->Read(std::chrono::milliseconds(1)), + llvm::Failed()); +} + +TEST_F(JSONRPCTransportTest, Read) { + std::string json = R"json({"str": "foo"})json"; + std::string message = formatv("{0}\n", json).str(); + ASSERT_THAT_EXPECTED(input.Write(message.data(), message.size()), + Succeeded()); + ASSERT_THAT_EXPECTED( + transport->Read(std::chrono::milliseconds(1)), + HasValue(testing::FieldsAre(/*str=*/"foo"))); +} + +TEST_F(JSONRPCTransportTest, ReadWithEOF) { + input.CloseWriteFileDescriptor(); + ASSERT_THAT_EXPECTED( + transport->Read(std::chrono::milliseconds(1)), + Failed()); +} + +TEST_F(JSONRPCTransportTest, ReadAfterClosed) { + input.CloseReadFileDescriptor(); + ASSERT_THAT_EXPECTED( + transport->Read(std::chrono::milliseconds(1)), + llvm::Failed()); +} + +TEST_F(JSONRPCTransportTest, Write) { + ASSERT_THAT_ERROR(transport->Write(JSONTestType{"foo"}), Succeeded()); + output.CloseWriteFileDescriptor(); + char buf[1024]; + Expected bytes_read = + output.Read(buf, sizeof(buf), std::chrono::milliseconds(1)); + ASSERT_THAT_EXPECTED(bytes_read, Succeeded()); + ASSERT_EQ(StringRef(buf, *bytes_read), StringRef(R"json({"str":"foo"})json" + "\n")); +} + +TEST_F(JSONRPCTransportTest, InvalidTransport) { + transport = std::make_unique(nullptr, nullptr); + ASSERT_THAT_EXPECTED( + transport->Read(std::chrono::milliseconds(1)), + Failed()); +} + +#ifndef _WIN32 +TEST_F(HTTPDelimitedJSONTransportTest, ReadWithTimeout) { + ASSERT_THAT_EXPECTED( + transport->Read(std::chrono::milliseconds(1)), + Failed()); +} + +TEST_F(JSONRPCTransportTest, ReadWithTimeout) { + ASSERT_THAT_EXPECTED( + transport->Read(std::chrono::milliseconds(1)), + Failed()); +} +#endif diff --git a/lldb/unittests/Host/PipeTest.cpp b/lldb/unittests/Host/PipeTest.cpp index 506f3d225a21e..a3a492648def6 100644 --- a/lldb/unittests/Host/PipeTest.cpp +++ b/lldb/unittests/Host/PipeTest.cpp @@ -10,9 +10,13 @@ #include "TestingSupport/SubsystemRAII.h" #include "lldb/Host/FileSystem.h" #include "lldb/Host/HostInfo.h" +#include "llvm/Testing/Support/Error.h" #include "gtest/gtest.h" +#include #include +#include #include +#include #include using namespace lldb_private; @@ -85,57 +89,53 @@ TEST_F(PipeTest, WriteWithTimeout) { char *read_ptr = reinterpret_cast(read_buf.data()); size_t write_bytes = 0; size_t read_bytes = 0; - size_t num_bytes = 0; // Write to the pipe until it is full. while (write_bytes + write_chunk_size <= buf_size) { - Status error = - pipe.WriteWithTimeout(write_ptr + write_bytes, write_chunk_size, - std::chrono::milliseconds(10), num_bytes); - if (error.Fail()) + llvm::Expected num_bytes = + pipe.Write(write_ptr + write_bytes, write_chunk_size, + std::chrono::milliseconds(10)); + if (num_bytes) { + write_bytes += *num_bytes; + } else { + ASSERT_THAT_ERROR(num_bytes.takeError(), llvm::Failed()); break; // The write buffer is full. - write_bytes += num_bytes; + } } ASSERT_LE(write_bytes + write_chunk_size, buf_size) << "Pipe buffer larger than expected"; // Attempt a write with a long timeout. auto start_time = std::chrono::steady_clock::now(); - ASSERT_THAT_ERROR(pipe.WriteWithTimeout(write_ptr + write_bytes, - write_chunk_size, - std::chrono::seconds(2), num_bytes) - .ToError(), - llvm::Failed()); + // TODO: Assert a specific error (EAGAIN?) here. + ASSERT_THAT_EXPECTED(pipe.Write(write_ptr + write_bytes, write_chunk_size, + std::chrono::seconds(2)), + llvm::Failed()); auto dur = std::chrono::steady_clock::now() - start_time; ASSERT_GE(dur, std::chrono::seconds(2)); // Attempt a write with a short timeout. start_time = std::chrono::steady_clock::now(); - ASSERT_THAT_ERROR( - pipe.WriteWithTimeout(write_ptr + write_bytes, write_chunk_size, - std::chrono::milliseconds(200), num_bytes) - .ToError(), - llvm::Failed()); + ASSERT_THAT_EXPECTED(pipe.Write(write_ptr + write_bytes, write_chunk_size, + std::chrono::milliseconds(200)), + llvm::Failed()); dur = std::chrono::steady_clock::now() - start_time; ASSERT_GE(dur, std::chrono::milliseconds(200)); ASSERT_LT(dur, std::chrono::seconds(2)); // Drain the pipe. while (read_bytes < write_bytes) { - ASSERT_THAT_ERROR( - pipe.ReadWithTimeout(read_ptr + read_bytes, write_bytes - read_bytes, - std::chrono::milliseconds(10), num_bytes) - .ToError(), - llvm::Succeeded()); - read_bytes += num_bytes; + llvm::Expected num_bytes = + pipe.Read(read_ptr + read_bytes, write_bytes - read_bytes, + std::chrono::milliseconds(10)); + ASSERT_THAT_EXPECTED(num_bytes, llvm::Succeeded()); + read_bytes += *num_bytes; } // Be sure the pipe is empty. - ASSERT_THAT_ERROR(pipe.ReadWithTimeout(read_ptr + read_bytes, 100, - std::chrono::milliseconds(10), - num_bytes) - .ToError(), - llvm::Failed()); + ASSERT_THAT_EXPECTED( + pipe.Read(read_ptr + read_bytes, 100, std::chrono::milliseconds(10)), + llvm::Failed()); // Check that we got what we wrote. ASSERT_EQ(write_bytes, read_bytes); @@ -144,9 +144,56 @@ TEST_F(PipeTest, WriteWithTimeout) { read_buf.begin())); // Write to the pipe again and check that it succeeds. - ASSERT_THAT_ERROR(pipe.WriteWithTimeout(write_ptr, write_chunk_size, - std::chrono::milliseconds(10), - num_bytes) - .ToError(), - llvm::Succeeded()); + ASSERT_THAT_EXPECTED( + pipe.Write(write_ptr, write_chunk_size, std::chrono::milliseconds(10)), + llvm::Succeeded()); +} + +TEST_F(PipeTest, ReadWithTimeout) { + Pipe pipe; + ASSERT_THAT_ERROR(pipe.CreateNew(false).ToError(), llvm::Succeeded()); + + char buf[100]; + // The pipe is initially empty. A polling read returns immediately. + ASSERT_THAT_EXPECTED(pipe.Read(buf, sizeof(buf), std::chrono::seconds(0)), + llvm::Failed()); + + // With a timeout, we should wait for at least this amount of time (but not + // too much). + auto start = std::chrono::steady_clock::now(); + ASSERT_THAT_EXPECTED( + pipe.Read(buf, sizeof(buf), std::chrono::milliseconds(200)), + llvm::Failed()); + auto dur = std::chrono::steady_clock::now() - start; + EXPECT_GT(dur, std::chrono::milliseconds(200)); + EXPECT_LT(dur, std::chrono::seconds(2)); + + // Write something into the pipe, and read it back. The blocking read call + // should return even though it hasn't filled the buffer. + llvm::StringRef hello_world("Hello world!"); + ASSERT_THAT_EXPECTED(pipe.Write(hello_world.data(), hello_world.size()), + llvm::HasValue(hello_world.size())); + ASSERT_THAT_EXPECTED(pipe.Read(buf, sizeof(buf)), + llvm::HasValue(hello_world.size())); + EXPECT_EQ(llvm::StringRef(buf, hello_world.size()), hello_world); + + // Now write something and try to read it in chunks. + memset(buf, 0, sizeof(buf)); + ASSERT_THAT_EXPECTED(pipe.Write(hello_world.data(), hello_world.size()), + llvm::HasValue(hello_world.size())); + ASSERT_THAT_EXPECTED(pipe.Read(buf, 4), llvm::HasValue(4)); + ASSERT_THAT_EXPECTED(pipe.Read(buf + 4, sizeof(buf) - 4), + llvm::HasValue(hello_world.size() - 4)); + EXPECT_EQ(llvm::StringRef(buf, hello_world.size()), hello_world); + + // A blocking read should wait until the data arrives. + memset(buf, 0, sizeof(buf)); + std::future> future_num_bytes = std::async( + std::launch::async, [&] { return pipe.Read(buf, sizeof(buf)); }); + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + ASSERT_THAT_EXPECTED(pipe.Write(hello_world.data(), hello_world.size()), + llvm::HasValue(hello_world.size())); + ASSERT_THAT_EXPECTED(future_num_bytes.get(), + llvm::HasValue(hello_world.size())); + EXPECT_EQ(llvm::StringRef(buf, hello_world.size()), hello_world); } diff --git a/lldb/unittests/Host/SocketTest.cpp b/lldb/unittests/Host/SocketTest.cpp index 3a356d11ba1a5..6b5efe5110a75 100644 --- a/lldb/unittests/Host/SocketTest.cpp +++ b/lldb/unittests/Host/SocketTest.cpp @@ -85,6 +85,43 @@ TEST_P(SocketTest, DomainListenConnectAccept) { std::unique_ptr socket_b_up; CreateDomainConnectedSockets(Path, &socket_a_up, &socket_b_up); } + +TEST_P(SocketTest, DomainMainLoopAccept) { + llvm::SmallString<64> Path; + std::error_code EC = + llvm::sys::fs::createUniqueDirectory("DomainListenConnectAccept", Path); + ASSERT_FALSE(EC); + llvm::sys::path::append(Path, "test"); + + // Skip the test if the $TMPDIR is too long to hold a domain socket. + if (Path.size() > 107u) + return; + + auto listen_socket_up = std::make_unique( + /*should_close=*/true, /*child_process_inherit=*/false); + Status error = listen_socket_up->Listen(Path, 5); + ASSERT_THAT_ERROR(error.ToError(), llvm::Succeeded()); + ASSERT_TRUE(listen_socket_up->IsValid()); + + MainLoop loop; + std::unique_ptr accepted_socket_up; + auto expected_handles = listen_socket_up->Accept( + loop, [&accepted_socket_up, &loop](std::unique_ptr sock_up) { + accepted_socket_up = std::move(sock_up); + loop.RequestTermination(); + }); + ASSERT_THAT_EXPECTED(expected_handles, llvm::Succeeded()); + + auto connect_socket_up = std::make_unique( + /*should_close=*/true, /*child_process_inherit=*/false); + ASSERT_THAT_ERROR(connect_socket_up->Connect(Path).ToError(), + llvm::Succeeded()); + ASSERT_TRUE(connect_socket_up->IsValid()); + + loop.Run(); + ASSERT_TRUE(accepted_socket_up); + ASSERT_TRUE(accepted_socket_up->IsValid()); +} #endif TEST_P(SocketTest, TCPListen0ConnectAccept) { @@ -109,9 +146,9 @@ TEST_P(SocketTest, TCPMainLoopAccept) { ASSERT_TRUE(listen_socket_up->IsValid()); MainLoop loop; - std::unique_ptr accepted_socket_up; + std::unique_ptr accepted_socket_up; auto expected_handles = listen_socket_up->Accept( - loop, [&accepted_socket_up, &loop](std::unique_ptr sock_up) { + loop, [&accepted_socket_up, &loop](std::unique_ptr sock_up) { accepted_socket_up = std::move(sock_up); loop.RequestTermination(); }); @@ -168,12 +205,30 @@ TEST_P(SocketTest, TCPListen0GetPort) { if (!HostSupportsIPv4()) return; llvm::Expected> sock = - Socket::TcpListen("10.10.12.3:0", false); + Socket::TcpListen("10.10.12.3:0", 5); ASSERT_THAT_EXPECTED(sock, llvm::Succeeded()); ASSERT_TRUE(sock.get()->IsValid()); EXPECT_NE(sock.get()->GetLocalPortNumber(), 0); } +TEST_P(SocketTest, TCPListen0GetListeningConnectionURI) { + if (!HostSupportsProtocol()) + return; + + std::string addr = llvm::formatv("[{0}]:0", GetParam().localhost_ip).str(); + llvm::Expected> sock = + Socket::TcpListen(addr, false); + ASSERT_THAT_EXPECTED(sock, llvm::Succeeded()); + ASSERT_TRUE(sock.get()->IsValid()); + + EXPECT_THAT( + sock.get()->GetListeningConnectionURI(), + testing::ElementsAre(llvm::formatv("connection://[{0}]:{1}", + GetParam().localhost_ip, + sock->get()->GetLocalPortNumber()) + .str())); +} + TEST_P(SocketTest, TCPGetConnectURI) { std::unique_ptr socket_a_up; std::unique_ptr socket_b_up; diff --git a/lldb/unittests/Protocol/CMakeLists.txt b/lldb/unittests/Protocol/CMakeLists.txt new file mode 100644 index 0000000000000..801662b0544d8 --- /dev/null +++ b/lldb/unittests/Protocol/CMakeLists.txt @@ -0,0 +1,12 @@ +add_lldb_unittest(ProtocolTests + ProtocolMCPTest.cpp + ProtocolMCPServerTest.cpp + + LINK_LIBS + lldbCore + lldbUtility + lldbHost + lldbPluginPlatformMacOSX + lldbPluginProtocolServerMCP + LLVMTestingSupport + ) diff --git a/lldb/unittests/Protocol/ProtocolMCPServerTest.cpp b/lldb/unittests/Protocol/ProtocolMCPServerTest.cpp new file mode 100644 index 0000000000000..b2dcc740b5efd --- /dev/null +++ b/lldb/unittests/Protocol/ProtocolMCPServerTest.cpp @@ -0,0 +1,327 @@ +//===-- ProtocolServerMCPTest.cpp -----------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "Plugins/Platform/MacOSX/PlatformRemoteMacOSX.h" +#include "Plugins/Protocol/MCP/MCPError.h" +#include "Plugins/Protocol/MCP/ProtocolServerMCP.h" +#include "TestingSupport/Host/SocketTestUtilities.h" +#include "TestingSupport/SubsystemRAII.h" +#include "lldb/Core/ProtocolServer.h" +#include "lldb/Host/FileSystem.h" +#include "lldb/Host/HostInfo.h" +#include "lldb/Host/JSONTransport.h" +#include "lldb/Host/Socket.h" +#include "llvm/Testing/Support/Error.h" +#include "gtest/gtest.h" + +using namespace llvm; +using namespace lldb; +using namespace lldb_private; +using namespace lldb_private::mcp::protocol; + +namespace { +class TestProtocolServerMCP : public lldb_private::mcp::ProtocolServerMCP { +public: + using ProtocolServerMCP::AddNotificationHandler; + using ProtocolServerMCP::AddRequestHandler; + using ProtocolServerMCP::AddResourceProvider; + using ProtocolServerMCP::AddTool; + using ProtocolServerMCP::GetSocket; + using ProtocolServerMCP::ProtocolServerMCP; +}; + +class TestJSONTransport : public lldb_private::JSONRPCTransport { +public: + using JSONRPCTransport::JSONRPCTransport; + using JSONRPCTransport::ReadImpl; + using JSONRPCTransport::WriteImpl; +}; + +/// Test tool that returns it argument as text. +class TestTool : public mcp::Tool { +public: + using mcp::Tool::Tool; + + virtual llvm::Expected + Call(const ToolArguments &args) override { + std::string argument; + if (const json::Object *args_obj = + std::get(args).getAsObject()) { + if (const json::Value *s = args_obj->get("arguments")) { + argument = s->getAsString().value_or(""); + } + } + + mcp::protocol::TextResult text_result; + text_result.content.emplace_back(mcp::protocol::TextContent{{argument}}); + return text_result; + } +}; + +class TestResourceProvider : public mcp::ResourceProvider { + using mcp::ResourceProvider::ResourceProvider; + + virtual std::vector GetResources() const override { + std::vector resources; + + Resource resource; + resource.uri = "lldb://foo/bar"; + resource.name = "name"; + resource.description = "description"; + resource.mimeType = "application/json"; + + resources.push_back(resource); + return resources; + } + + virtual llvm::Expected + ReadResource(llvm::StringRef uri) const override { + if (uri != "lldb://foo/bar") + return llvm::make_error(uri.str()); + + ResourceContents contents; + contents.uri = "lldb://foo/bar"; + contents.mimeType = "application/json"; + contents.text = "foobar"; + + ResourceResult result; + result.contents.push_back(contents); + return result; + } +}; + +/// Test tool that returns an error. +class ErrorTool : public mcp::Tool { +public: + using mcp::Tool::Tool; + + virtual llvm::Expected + Call(const ToolArguments &args) override { + return llvm::createStringError("error"); + } +}; + +/// Test tool that fails but doesn't return an error. +class FailTool : public mcp::Tool { +public: + using mcp::Tool::Tool; + + virtual llvm::Expected + Call(const ToolArguments &args) override { + mcp::protocol::TextResult text_result; + text_result.content.emplace_back(mcp::protocol::TextContent{{"failed"}}); + text_result.isError = true; + return text_result; + } +}; + +class ProtocolServerMCPTest : public ::testing::Test { +public: + SubsystemRAII subsystems; + DebuggerSP m_debugger_sp; + + lldb::IOObjectSP m_io_sp; + std::unique_ptr m_transport_up; + std::unique_ptr m_server_up; + + static constexpr llvm::StringLiteral k_localhost = "localhost"; + + llvm::Error Write(llvm::StringRef message) { + return m_transport_up->WriteImpl(llvm::formatv("{0}\n", message).str()); + } + + llvm::Expected Read() { + return m_transport_up->ReadImpl(std::chrono::milliseconds(100)); + } + + void SetUp() { + // Create a debugger. + ArchSpec arch("arm64-apple-macosx-"); + Platform::SetHostPlatform( + PlatformRemoteMacOSX::CreateInstance(true, &arch)); + m_debugger_sp = Debugger::CreateInstance(); + + // Create & start the server. + ProtocolServer::Connection connection; + connection.protocol = Socket::SocketProtocol::ProtocolTcp; + connection.name = llvm::formatv("{0}:0", k_localhost).str(); + m_server_up = std::make_unique(); + m_server_up->AddTool(std::make_unique("test", "test tool")); + m_server_up->AddResourceProvider(std::make_unique()); + ASSERT_THAT_ERROR(m_server_up->Start(connection), llvm::Succeeded()); + + // Connect to the server over a TCP socket. + auto connect_socket_up = std::make_unique(true, false); + ASSERT_THAT_ERROR(connect_socket_up + ->Connect(llvm::formatv("{0}:{1}", k_localhost, + static_cast( + m_server_up->GetSocket()) + ->GetLocalPortNumber()) + .str()) + .ToError(), + llvm::Succeeded()); + + // Set up JSON transport for the client. + m_io_sp = std::move(connect_socket_up); + m_transport_up = std::make_unique(m_io_sp, m_io_sp); + } + + void TearDown() { + // Stop the server. + ASSERT_THAT_ERROR(m_server_up->Stop(), llvm::Succeeded()); + } +}; + +} // namespace + +TEST_F(ProtocolServerMCPTest, Intialization) { + llvm::StringLiteral request = + R"json({"method":"initialize","params":{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"lldb-unit","version":"0.1.0"}},"jsonrpc":"2.0","id":0})json"; + llvm::StringLiteral response = + R"json( {"id":0,"jsonrpc":"2.0","result":{"capabilities":{"resources":{"listChanged":false,"subscribe":false},"tools":{"listChanged":true}},"protocolVersion":"2024-11-05","serverInfo":{"name":"lldb-mcp","version":"0.1.0"}}})json"; + + ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); + + llvm::Expected response_str = Read(); + ASSERT_THAT_EXPECTED(response_str, llvm::Succeeded()); + + llvm::Expected response_json = json::parse(*response_str); + ASSERT_THAT_EXPECTED(response_json, llvm::Succeeded()); + + llvm::Expected expected_json = json::parse(response); + ASSERT_THAT_EXPECTED(expected_json, llvm::Succeeded()); + + EXPECT_EQ(*response_json, *expected_json); +} + +TEST_F(ProtocolServerMCPTest, ToolsList) { + llvm::StringLiteral request = + R"json({"method":"tools/list","params":{},"jsonrpc":"2.0","id":1})json"; + llvm::StringLiteral response = + R"json({"id":1,"jsonrpc":"2.0","result":{"tools":[{"description":"test tool","inputSchema":{"type":"object"},"name":"test"},{"description":"Run an lldb command.","inputSchema":{"properties":{"arguments":{"type":"string"},"debugger_id":{"type":"number"}},"required":["debugger_id"],"type":"object"},"name":"lldb_command"}]}})json"; + + ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); + + llvm::Expected response_str = Read(); + ASSERT_THAT_EXPECTED(response_str, llvm::Succeeded()); + + llvm::Expected response_json = json::parse(*response_str); + ASSERT_THAT_EXPECTED(response_json, llvm::Succeeded()); + + llvm::Expected expected_json = json::parse(response); + ASSERT_THAT_EXPECTED(expected_json, llvm::Succeeded()); + + EXPECT_EQ(*response_json, *expected_json); +} + +TEST_F(ProtocolServerMCPTest, ResourcesList) { + llvm::StringLiteral request = + R"json({"method":"resources/list","params":{},"jsonrpc":"2.0","id":2})json"; + llvm::StringLiteral response = + R"json({"id":2,"jsonrpc":"2.0","result":{"resources":[{"description":"description","mimeType":"application/json","name":"name","uri":"lldb://foo/bar"}]}})json"; + + ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); + + llvm::Expected response_str = Read(); + ASSERT_THAT_EXPECTED(response_str, llvm::Succeeded()); + + llvm::Expected response_json = json::parse(*response_str); + ASSERT_THAT_EXPECTED(response_json, llvm::Succeeded()); + + llvm::Expected expected_json = json::parse(response); + ASSERT_THAT_EXPECTED(expected_json, llvm::Succeeded()); + + EXPECT_EQ(*response_json, *expected_json); +} + +TEST_F(ProtocolServerMCPTest, ToolsCall) { + llvm::StringLiteral request = + R"json({"method":"tools/call","params":{"name":"test","arguments":{"arguments":"foo","debugger_id":0}},"jsonrpc":"2.0","id":11})json"; + llvm::StringLiteral response = + R"json({"id":11,"jsonrpc":"2.0","result":{"content":[{"text":"foo","type":"text"}],"isError":false}})json"; + + ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); + + llvm::Expected response_str = Read(); + ASSERT_THAT_EXPECTED(response_str, llvm::Succeeded()); + + llvm::Expected response_json = json::parse(*response_str); + ASSERT_THAT_EXPECTED(response_json, llvm::Succeeded()); + + llvm::Expected expected_json = json::parse(response); + ASSERT_THAT_EXPECTED(expected_json, llvm::Succeeded()); + + EXPECT_EQ(*response_json, *expected_json); +} + +TEST_F(ProtocolServerMCPTest, ToolsCallError) { + m_server_up->AddTool(std::make_unique("error", "error tool")); + + llvm::StringLiteral request = + R"json({"method":"tools/call","params":{"name":"error","arguments":{"arguments":"foo","debugger_id":0}},"jsonrpc":"2.0","id":11})json"; + llvm::StringLiteral response = + R"json({"error":{"code":-32603,"message":"error"},"id":11,"jsonrpc":"2.0"})json"; + + ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); + + llvm::Expected response_str = Read(); + ASSERT_THAT_EXPECTED(response_str, llvm::Succeeded()); + + llvm::Expected response_json = json::parse(*response_str); + ASSERT_THAT_EXPECTED(response_json, llvm::Succeeded()); + + llvm::Expected expected_json = json::parse(response); + ASSERT_THAT_EXPECTED(expected_json, llvm::Succeeded()); + + EXPECT_EQ(*response_json, *expected_json); +} + +TEST_F(ProtocolServerMCPTest, ToolsCallFail) { + m_server_up->AddTool(std::make_unique("fail", "fail tool")); + + llvm::StringLiteral request = + R"json({"method":"tools/call","params":{"name":"fail","arguments":{"arguments":"foo","debugger_id":0}},"jsonrpc":"2.0","id":11})json"; + llvm::StringLiteral response = + R"json({"id":11,"jsonrpc":"2.0","result":{"content":[{"text":"failed","type":"text"}],"isError":true}})json"; + + ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); + + llvm::Expected response_str = Read(); + ASSERT_THAT_EXPECTED(response_str, llvm::Succeeded()); + + llvm::Expected response_json = json::parse(*response_str); + ASSERT_THAT_EXPECTED(response_json, llvm::Succeeded()); + + llvm::Expected expected_json = json::parse(response); + ASSERT_THAT_EXPECTED(expected_json, llvm::Succeeded()); + + EXPECT_EQ(*response_json, *expected_json); +} + +TEST_F(ProtocolServerMCPTest, NotificationInitialized) { + bool handler_called = false; + std::condition_variable cv; + std::mutex mutex; + + m_server_up->AddNotificationHandler( + "notifications/initialized", + [&](const mcp::protocol::Notification ¬ification) { + { + std::lock_guard lock(mutex); + handler_called = true; + } + cv.notify_all(); + }); + llvm::StringLiteral request = + R"json({"method":"notifications/initialized","jsonrpc":"2.0"})json"; + + ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); + + std::unique_lock lock(mutex); + cv.wait(lock, [&] { return handler_called; }); +} diff --git a/lldb/unittests/Protocol/ProtocolMCPTest.cpp b/lldb/unittests/Protocol/ProtocolMCPTest.cpp new file mode 100644 index 0000000000000..ce8120cbfe9b9 --- /dev/null +++ b/lldb/unittests/Protocol/ProtocolMCPTest.cpp @@ -0,0 +1,330 @@ +//===-- ProtocolMCPTest.cpp -----------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "Plugins/Protocol/MCP/Protocol.h" +#include "TestingSupport/TestUtilities.h" +#include "llvm/Testing/Support/Error.h" +#include "gtest/gtest.h" + +using namespace lldb; +using namespace lldb_private; +using namespace lldb_private::mcp::protocol; + +TEST(ProtocolMCPTest, Request) { + Request request; + request.id = 1; + request.method = "foo"; + request.params = llvm::json::Object{{"key", "value"}}; + + llvm::Expected deserialized_request = roundtripJSON(request); + ASSERT_THAT_EXPECTED(deserialized_request, llvm::Succeeded()); + + EXPECT_EQ(request.id, deserialized_request->id); + EXPECT_EQ(request.method, deserialized_request->method); + EXPECT_EQ(request.params, deserialized_request->params); +} + +TEST(ProtocolMCPTest, Response) { + Response response; + response.id = 1; + response.result = llvm::json::Object{{"key", "value"}}; + + llvm::Expected deserialized_response = roundtripJSON(response); + ASSERT_THAT_EXPECTED(deserialized_response, llvm::Succeeded()); + + EXPECT_EQ(response.id, deserialized_response->id); + EXPECT_EQ(response.result, deserialized_response->result); +} + +TEST(ProtocolMCPTest, Notification) { + Notification notification; + notification.method = "notifyMethod"; + notification.params = llvm::json::Object{{"key", "value"}}; + + llvm::Expected deserialized_notification = + roundtripJSON(notification); + ASSERT_THAT_EXPECTED(deserialized_notification, llvm::Succeeded()); + + EXPECT_EQ(notification.method, deserialized_notification->method); + EXPECT_EQ(notification.params, deserialized_notification->params); +} + +TEST(ProtocolMCPTest, ToolCapability) { + ToolCapability tool_capability; + tool_capability.listChanged = true; + + llvm::Expected deserialized_tool_capability = + roundtripJSON(tool_capability); + ASSERT_THAT_EXPECTED(deserialized_tool_capability, llvm::Succeeded()); + + EXPECT_EQ(tool_capability.listChanged, + deserialized_tool_capability->listChanged); +} + +TEST(ProtocolMCPTest, Capabilities) { + ToolCapability tool_capability; + tool_capability.listChanged = true; + + Capabilities capabilities; + capabilities.tools = tool_capability; + + llvm::Expected deserialized_capabilities = + roundtripJSON(capabilities); + ASSERT_THAT_EXPECTED(deserialized_capabilities, llvm::Succeeded()); + + EXPECT_EQ(capabilities.tools.listChanged, + deserialized_capabilities->tools.listChanged); +} + +TEST(ProtocolMCPTest, TextContent) { + TextContent text_content; + text_content.text = "Sample text"; + + llvm::Expected deserialized_text_content = + roundtripJSON(text_content); + ASSERT_THAT_EXPECTED(deserialized_text_content, llvm::Succeeded()); + + EXPECT_EQ(text_content.text, deserialized_text_content->text); +} + +TEST(ProtocolMCPTest, TextResult) { + TextContent text_content1; + text_content1.text = "Text 1"; + + TextContent text_content2; + text_content2.text = "Text 2"; + + TextResult text_result; + text_result.content = {text_content1, text_content2}; + text_result.isError = true; + + llvm::Expected deserialized_text_result = + roundtripJSON(text_result); + ASSERT_THAT_EXPECTED(deserialized_text_result, llvm::Succeeded()); + + EXPECT_EQ(text_result.isError, deserialized_text_result->isError); + ASSERT_EQ(text_result.content.size(), + deserialized_text_result->content.size()); + EXPECT_EQ(text_result.content[0].text, + deserialized_text_result->content[0].text); + EXPECT_EQ(text_result.content[1].text, + deserialized_text_result->content[1].text); +} + +TEST(ProtocolMCPTest, ToolDefinition) { + ToolDefinition tool_definition; + tool_definition.name = "ToolName"; + tool_definition.description = "Tool Description"; + tool_definition.inputSchema = + llvm::json::Object{{"schemaKey", "schemaValue"}}; + + llvm::Expected deserialized_tool_definition = + roundtripJSON(tool_definition); + ASSERT_THAT_EXPECTED(deserialized_tool_definition, llvm::Succeeded()); + + EXPECT_EQ(tool_definition.name, deserialized_tool_definition->name); + EXPECT_EQ(tool_definition.description, + deserialized_tool_definition->description); + EXPECT_EQ(tool_definition.inputSchema, + deserialized_tool_definition->inputSchema); +} + +TEST(ProtocolMCPTest, MessageWithRequest) { + Request request; + request.id = 1; + request.method = "test_method"; + request.params = llvm::json::Object{{"param", "value"}}; + + Message message = request; + + llvm::Expected deserialized_message = roundtripJSON(message); + ASSERT_THAT_EXPECTED(deserialized_message, llvm::Succeeded()); + + ASSERT_TRUE(std::holds_alternative(*deserialized_message)); + const Request &deserialized_request = + std::get(*deserialized_message); + + EXPECT_EQ(request.id, deserialized_request.id); + EXPECT_EQ(request.method, deserialized_request.method); + EXPECT_EQ(request.params, deserialized_request.params); +} + +TEST(ProtocolMCPTest, MessageWithResponse) { + Response response; + response.id = 2; + response.result = llvm::json::Object{{"result", "success"}}; + + Message message = response; + + llvm::Expected deserialized_message = roundtripJSON(message); + ASSERT_THAT_EXPECTED(deserialized_message, llvm::Succeeded()); + + ASSERT_TRUE(std::holds_alternative(*deserialized_message)); + const Response &deserialized_response = + std::get(*deserialized_message); + + EXPECT_EQ(response.id, deserialized_response.id); + EXPECT_EQ(response.result, deserialized_response.result); +} + +TEST(ProtocolMCPTest, MessageWithNotification) { + Notification notification; + notification.method = "notification_method"; + notification.params = llvm::json::Object{{"notify", "data"}}; + + Message message = notification; + + llvm::Expected deserialized_message = roundtripJSON(message); + ASSERT_THAT_EXPECTED(deserialized_message, llvm::Succeeded()); + + ASSERT_TRUE(std::holds_alternative(*deserialized_message)); + const Notification &deserialized_notification = + std::get(*deserialized_message); + + EXPECT_EQ(notification.method, deserialized_notification.method); + EXPECT_EQ(notification.params, deserialized_notification.params); +} + +TEST(ProtocolMCPTest, MessageWithError) { + ErrorInfo error_info; + error_info.code = -32603; + error_info.message = "Internal error"; + + Error error; + error.id = 3; + error.error = error_info; + + Message message = error; + + llvm::Expected deserialized_message = roundtripJSON(message); + ASSERT_THAT_EXPECTED(deserialized_message, llvm::Succeeded()); + + ASSERT_TRUE(std::holds_alternative(*deserialized_message)); + const Error &deserialized_error = std::get(*deserialized_message); + + EXPECT_EQ(error.id, deserialized_error.id); + EXPECT_EQ(error.error.code, deserialized_error.error.code); + EXPECT_EQ(error.error.message, deserialized_error.error.message); +} + +TEST(ProtocolMCPTest, ResponseWithError) { + ErrorInfo error_info; + error_info.code = -32700; + error_info.message = "Parse error"; + + Response response; + response.id = 4; + response.error = error_info; + + llvm::Expected deserialized_response = roundtripJSON(response); + ASSERT_THAT_EXPECTED(deserialized_response, llvm::Succeeded()); + + EXPECT_EQ(response.id, deserialized_response->id); + EXPECT_FALSE(deserialized_response->result.has_value()); + ASSERT_TRUE(deserialized_response->error.has_value()); + EXPECT_EQ(response.error->code, deserialized_response->error->code); + EXPECT_EQ(response.error->message, deserialized_response->error->message); +} + +TEST(ProtocolMCPTest, Resource) { + Resource resource; + resource.uri = "resource://example/test"; + resource.name = "Test Resource"; + resource.description = "A test resource for unit testing"; + resource.mimeType = "text/plain"; + + llvm::Expected deserialized_resource = roundtripJSON(resource); + ASSERT_THAT_EXPECTED(deserialized_resource, llvm::Succeeded()); + + EXPECT_EQ(resource.uri, deserialized_resource->uri); + EXPECT_EQ(resource.name, deserialized_resource->name); + EXPECT_EQ(resource.description, deserialized_resource->description); + EXPECT_EQ(resource.mimeType, deserialized_resource->mimeType); +} + +TEST(ProtocolMCPTest, ResourceWithoutOptionals) { + Resource resource; + resource.uri = "resource://example/minimal"; + resource.name = "Minimal Resource"; + + llvm::Expected deserialized_resource = roundtripJSON(resource); + ASSERT_THAT_EXPECTED(deserialized_resource, llvm::Succeeded()); + + EXPECT_EQ(resource.uri, deserialized_resource->uri); + EXPECT_EQ(resource.name, deserialized_resource->name); + EXPECT_TRUE(deserialized_resource->description.empty()); + EXPECT_TRUE(deserialized_resource->mimeType.empty()); +} + +TEST(ProtocolMCPTest, ResourceContents) { + ResourceContents contents; + contents.uri = "resource://example/content"; + contents.text = "This is the content of the resource"; + contents.mimeType = "text/plain"; + + llvm::Expected deserialized_contents = + roundtripJSON(contents); + ASSERT_THAT_EXPECTED(deserialized_contents, llvm::Succeeded()); + + EXPECT_EQ(contents.uri, deserialized_contents->uri); + EXPECT_EQ(contents.text, deserialized_contents->text); + EXPECT_EQ(contents.mimeType, deserialized_contents->mimeType); +} + +TEST(ProtocolMCPTest, ResourceContentsWithoutMimeType) { + ResourceContents contents; + contents.uri = "resource://example/content-no-mime"; + contents.text = "Content without mime type specified"; + + llvm::Expected deserialized_contents = + roundtripJSON(contents); + ASSERT_THAT_EXPECTED(deserialized_contents, llvm::Succeeded()); + + EXPECT_EQ(contents.uri, deserialized_contents->uri); + EXPECT_EQ(contents.text, deserialized_contents->text); + EXPECT_TRUE(deserialized_contents->mimeType.empty()); +} + +TEST(ProtocolMCPTest, ResourceResult) { + ResourceContents contents1; + contents1.uri = "resource://example/content1"; + contents1.text = "First resource content"; + contents1.mimeType = "text/plain"; + + ResourceContents contents2; + contents2.uri = "resource://example/content2"; + contents2.text = "Second resource content"; + contents2.mimeType = "application/json"; + + ResourceResult result; + result.contents = {contents1, contents2}; + + llvm::Expected deserialized_result = roundtripJSON(result); + ASSERT_THAT_EXPECTED(deserialized_result, llvm::Succeeded()); + + ASSERT_EQ(result.contents.size(), deserialized_result->contents.size()); + + EXPECT_EQ(result.contents[0].uri, deserialized_result->contents[0].uri); + EXPECT_EQ(result.contents[0].text, deserialized_result->contents[0].text); + EXPECT_EQ(result.contents[0].mimeType, + deserialized_result->contents[0].mimeType); + + EXPECT_EQ(result.contents[1].uri, deserialized_result->contents[1].uri); + EXPECT_EQ(result.contents[1].text, deserialized_result->contents[1].text); + EXPECT_EQ(result.contents[1].mimeType, + deserialized_result->contents[1].mimeType); +} + +TEST(ProtocolMCPTest, ResourceResultEmpty) { + ResourceResult result; + + llvm::Expected deserialized_result = roundtripJSON(result); + ASSERT_THAT_EXPECTED(deserialized_result, llvm::Succeeded()); + + EXPECT_TRUE(deserialized_result->contents.empty()); +} diff --git a/lldb/unittests/TestingSupport/Host/PipeTestUtilities.h b/lldb/unittests/TestingSupport/Host/PipeTestUtilities.h new file mode 100644 index 0000000000000..87a85ad77e65d --- /dev/null +++ b/lldb/unittests/TestingSupport/Host/PipeTestUtilities.h @@ -0,0 +1,28 @@ +//===-- PipeTestUtilities.cpp ---------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLDB_UNITTESTS_TESTINGSUPPORT_PIPETESTUTILITIES_H +#define LLDB_UNITTESTS_TESTINGSUPPORT_PIPETESTUTILITIES_H + +#include "lldb/Host/Pipe.h" +#include "llvm/Testing/Support/Error.h" +#include "gtest/gtest.h" + +/// A base class for tests that need a pair of pipes for communication. +class PipePairTest : public testing::Test { +protected: + lldb_private::Pipe input; + lldb_private::Pipe output; + + void SetUp() override { + ASSERT_THAT_ERROR(input.CreateNew(false).ToError(), llvm::Succeeded()); + ASSERT_THAT_ERROR(output.CreateNew(false).ToError(), llvm::Succeeded()); + } +}; + +#endif diff --git a/lldb/unittests/TestingSupport/TestUtilities.h b/lldb/unittests/TestingSupport/TestUtilities.h index 7d040d64db8d8..a8bdda6ad33ae 100644 --- a/lldb/unittests/TestingSupport/TestUtilities.h +++ b/lldb/unittests/TestingSupport/TestUtilities.h @@ -56,6 +56,15 @@ class TestFile { std::string Buffer; }; + +template static llvm::Expected roundtripJSON(const T &input) { + llvm::json::Value value = toJSON(input); + llvm::json::Path::Root root; + T output; + if (!fromJSON(value, output, root)) + return root.getError(); + return output; +} } // namespace lldb_private #endif