From c991aed05e316b144263f6967759aaccfbec5ce5 Mon Sep 17 00:00:00 2001 From: Phosphorus-M Date: Sun, 26 Oct 2025 12:00:32 -0300 Subject: [PATCH] =?UTF-8?q?feat:=20add=20Socket.IO=20support=20with=20sock?= =?UTF-8?q?etioxide=20integration=20=20=F0=9F=98=AD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Cargo.lock | 286 ++++++++++---- Cargo.toml | 2 + examples/socketio-server/.env | 5 + examples/socketio-server/Cargo.toml | 11 + examples/socketio-server/compose.yaml | 20 + examples/socketio-server/config/config.toml | 9 + .../migrations/20251024204317_tasks.sql | 6 + examples/socketio-server/src/database.rs | 42 ++ examples/socketio-server/src/main.rs | 134 +++++++ sword-macros/src/lib.rs | 145 +++++++ sword-macros/src/websocket/expand.rs | 371 ++++++++++++++++++ sword-macros/src/websocket/mod.rs | 4 + sword-macros/src/websocket/parsing.rs | 52 +++ sword/Cargo.toml | 5 + sword/src/core/application/app.rs | 60 ++- sword/src/core/application/builder.rs | 84 +++- sword/src/lib.rs | 7 +- sword/src/prelude.rs | 3 + sword/src/web/mod.rs | 4 + sword/src/web/websocket/handler.rs | 39 ++ sword/src/web/websocket/mod.rs | 37 ++ sword/src/web/websocket/types.rs | 8 + 22 files changed, 1243 insertions(+), 91 deletions(-) create mode 100644 examples/socketio-server/.env create mode 100644 examples/socketio-server/Cargo.toml create mode 100644 examples/socketio-server/compose.yaml create mode 100644 examples/socketio-server/config/config.toml create mode 100644 examples/socketio-server/config/migrations/20251024204317_tasks.sql create mode 100644 examples/socketio-server/src/database.rs create mode 100644 examples/socketio-server/src/main.rs create mode 100644 sword-macros/src/websocket/expand.rs create mode 100644 sword-macros/src/websocket/mod.rs create mode 100644 sword-macros/src/websocket/parsing.rs create mode 100644 sword/src/web/websocket/handler.rs create mode 100644 sword/src/web/websocket/mod.rs create mode 100644 sword/src/web/websocket/types.rs diff --git a/Cargo.lock b/Cargo.lock index e4dc21b..1c8cc83 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,21 +2,6 @@ # It is not intended for manual editing. version = 4 -[[package]] -name = "addr2line" -version = "0.25.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b5d307320b3181d6d7954e663bd7c774a838b8220fe0593c86d9fb09f498b4b" -dependencies = [ - "gimli", -] - -[[package]] -name = "adler2" -version = "2.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa" - [[package]] name = "aead" version = "0.5.2" @@ -93,6 +78,15 @@ version = "1.0.100" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a23eb6b1614318a8071c9b2521f36b424b2c83db5eb3a0fead4a6c0809af6e61" +[[package]] +name = "arbitrary" +version = "1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3d036a3c4ab069c7b410a2ce876bd74808d2d0888a82667669f8e783a898bf1" +dependencies = [ + "derive_arbitrary", +] + [[package]] name = "arrayvec" version = "0.7.6" @@ -246,21 +240,6 @@ dependencies = [ "serde_json", ] -[[package]] -name = "backtrace" -version = "0.3.76" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bb531853791a215d7c62a30daf0dde835f381ab5de4589cfe7c649d2cbe92bd6" -dependencies = [ - "addr2line", - "cfg-if", - "libc", - "miniz_oxide", - "object", - "rustc-demangle", - "windows-link", -] - [[package]] name = "base64" version = "0.22.1" @@ -385,6 +364,9 @@ name = "bytes" version = "1.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d71b6127be86fdcfddb610f7182ac57211d4b18a3e9c82eb2d17662f2227ad6a" +dependencies = [ + "serde", +] [[package]] name = "bytesize" @@ -674,6 +656,17 @@ dependencies = [ "powerfmt", ] +[[package]] +name = "derive_arbitrary" +version = "1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e567bd82dcff979e4b03460c307b3cdc9e96fde3d73bed1496d2bc75d9dd62a" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.106", +] + [[package]] name = "diff" version = "0.1.13" @@ -744,7 +737,7 @@ dependencies = [ "subsecond", "thiserror", "tracing", - "tungstenite", + "tungstenite 0.27.0", "warnings", ] @@ -816,6 +809,45 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "engineioxide" +version = "0.17.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88a4ef9fd57bc7e9fbe59550d3cba88536fbe47fba05ab33088edc4b09d3267a" +dependencies = [ + "base64", + "bytes", + "engineioxide-core", + "futures-core", + "futures-util", + "http", + "http-body", + "http-body-util", + "hyper", + "hyper-util", + "pin-project-lite", + "serde", + "serde_json", + "smallvec", + "thiserror", + "tokio", + "tokio-tungstenite", + "tower-layer", + "tower-service", +] + +[[package]] +name = "engineioxide-core" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "04e5d58eb7374df380cbb53ef65f9c35f544c9c217528adb1458c8df05978475" +dependencies = [ + "base64", + "bytes", + "rand 0.9.2", + "serde", +] + [[package]] name = "equivalent" version = "1.0.2" @@ -1067,12 +1099,6 @@ dependencies = [ "polyval", ] -[[package]] -name = "gimli" -version = "0.32.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e629b9b98ef3dd8afe6ca2bd0f89306cec16d43d907889945bc5d6687f2f13c7" - [[package]] name = "hashbrown" version = "0.12.3" @@ -1423,17 +1449,6 @@ dependencies = [ "rustversion", ] -[[package]] -name = "io-uring" -version = "0.7.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "046fa2d4d00aea763528b4950358d0ead425372445dc8ff86312b3c69ff7727b" -dependencies = [ - "bitflags", - "cfg-if", - "libc", -] - [[package]] name = "itoa" version = "1.0.15" @@ -1591,15 +1606,6 @@ version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" -[[package]] -name = "miniz_oxide" -version = "0.8.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1fa76a2c86f704bdb222d66965fb3d63269ce38518b83cb0575fca855ebb6316" -dependencies = [ - "adler2", -] - [[package]] name = "mio" version = "1.0.4" @@ -1690,15 +1696,6 @@ dependencies = [ "libm", ] -[[package]] -name = "object" -version = "0.37.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ff76201f031d8863c38aa7f905eca4f53abbfa15f609db4277d44cd8938f33fe" -dependencies = [ - "memchr", -] - [[package]] name = "once_cell" version = "1.21.3" @@ -1740,6 +1737,12 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "paste" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" + [[package]] name = "pem-rfc7468" version = "0.7.0" @@ -2115,6 +2118,29 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "rmp" +version = "0.8.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "228ed7c16fa39782c3b3468e974aec2795e9089153cd08ee2e9aefb3613334c4" +dependencies = [ + "byteorder", + "num-traits", + "paste", +] + +[[package]] +name = "rmpv" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "58450723cd9ee93273ce44a20b6ec4efe17f8ed2e3631474387bfdecf18bb2a9" +dependencies = [ + "num-traits", + "rmp", + "serde", + "serde_bytes", +] + [[package]] name = "rsa" version = "0.9.8" @@ -2166,12 +2192,6 @@ dependencies = [ "serde_json", ] -[[package]] -name = "rustc-demangle" -version = "0.1.26" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "56f7d92ca342cea22a06f2121d944b4fd82af56988c270852495420f961d4ace" - [[package]] name = "rustc-hash" version = "2.1.1" @@ -2225,6 +2245,16 @@ dependencies = [ "serde_derive", ] +[[package]] +name = "serde_bytes" +version = "0.11.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5d440709e79d88e51ac01c4b72fc6cb7314017bb7da9eeff678aa94c10e3ea8" +dependencies = [ + "serde", + "serde_core", +] + [[package]] name = "serde_core" version = "1.0.228" @@ -2378,6 +2408,69 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "socketio-server" +version = "0.1.0" +dependencies = [ + "dotenv", + "serde", + "sqlx", + "sword", + "sword-macros", +] + +[[package]] +name = "socketioxide" +version = "0.17.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "476190583b592f1e3d55584269600f2d3c4f18af36adad03c41c27e82dcb6bd5" +dependencies = [ + "bytes", + "engineioxide", + "futures-core", + "futures-util", + "http", + "http-body", + "hyper", + "matchit", + "pin-project-lite", + "serde", + "socketioxide-core", + "socketioxide-parser-common", + "thiserror", + "tokio", + "tower-layer", + "tower-service", +] + +[[package]] +name = "socketioxide-core" +version = "0.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b07b95089a961994921d23dd6e70792a06f5daa250b5ec8919f6f9de371d2cc5" +dependencies = [ + "arbitrary", + "bytes", + "engineioxide-core", + "futures-core", + "serde", + "smallvec", + "thiserror", +] + +[[package]] +name = "socketioxide-parser-common" +version = "0.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fe3b57122bf9c17fe8c2f364e1d307983068396cfb1b0407ec897de411f8033" +dependencies = [ + "bytes", + "itoa", + "serde", + "serde_json", + "socketioxide-core", +] + [[package]] name = "spin" version = "0.9.8" @@ -2438,6 +2531,7 @@ dependencies = [ "sha2", "smallvec", "thiserror", + "time", "tokio", "tokio-stream", "tracing", @@ -2520,6 +2614,7 @@ dependencies = [ "sqlx-core", "stringprep", "thiserror", + "time", "tracing", "whoami", ] @@ -2557,6 +2652,7 @@ dependencies = [ "sqlx-core", "stringprep", "thiserror", + "time", "tracing", "whoami", ] @@ -2581,6 +2677,7 @@ dependencies = [ "serde_urlencoded", "sqlx-core", "thiserror", + "time", "tracing", "url", ] @@ -2663,10 +2760,12 @@ dependencies = [ "http-body-util", "inventory", "regex-lite", + "rmpv", "serde", "serde_json", "serde_path_to_error", "serde_urlencoded", + "socketioxide", "subsecond", "sword-macros", "thiserror", @@ -2675,6 +2774,7 @@ dependencies = [ "tower", "tower-cookies", "tower-http", + "tracing", "validator", ] @@ -2825,29 +2925,26 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.47.1" +version = "1.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89e49afdadebb872d3145a5638b59eb0691ea23e46ca484037cfab3b76b95038" +checksum = "ff360e02eab121e0bc37a2d3b4d4dc622e6eda3a8e5253d5435ecf5bd4c68408" dependencies = [ - "backtrace", "bytes", - "io-uring", "libc", "mio", "parking_lot", "pin-project-lite", "signal-hook-registry", - "slab", "socket2", "tokio-macros", - "windows-sys 0.59.0", + "windows-sys 0.61.1", ] [[package]] name = "tokio-macros" -version = "2.5.0" +version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6e06d43f1345a3bcd39f6a56dbb7dcab2ba47e68e8ac134855e7e2bdbaf8cab8" +checksum = "af407857209536a95c8e56f8231ef2c2e2aff839b22e07a1ffcbc617e9db9fa5" dependencies = [ "proc-macro2", "quote", @@ -2865,6 +2962,18 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-tungstenite" +version = "0.26.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a9daff607c6d2bf6c16fd681ccb7eecc83e4e2cdc1ca067ffaadfca5de7f084" +dependencies = [ + "futures-util", + "log", + "tokio", + "tungstenite 0.26.2", +] + [[package]] name = "toml" version = "0.9.7" @@ -3026,6 +3135,23 @@ version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" +[[package]] +name = "tungstenite" +version = "0.26.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4793cb5e56680ecbb1d843515b23b6de9a75eb04b66643e256a396d43be33c13" +dependencies = [ + "bytes", + "data-encoding", + "http", + "httparse", + "log", + "rand 0.9.2", + "sha1", + "thiserror", + "utf-8", +] + [[package]] name = "tungstenite" version = "0.27.0" diff --git a/Cargo.toml b/Cargo.toml index ad7fcd1..6ec2f36 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,3 +25,5 @@ async-trait = "0.1.83" tower-http = { version = "0.6.6", features = ["limit", "cors", "trace"] } tracing = "0.1.41" tracing-subscriber = { version = "0.3.19", features = ["env-filter"] } + +socketioxide = { version = "0.18.0", features = ["axum-websockets"] } \ No newline at end of file diff --git a/examples/socketio-server/.env b/examples/socketio-server/.env new file mode 100644 index 0000000..c9dde34 --- /dev/null +++ b/examples/socketio-server/.env @@ -0,0 +1,5 @@ +POSTGRES_USER=user +POSTGRES_PASSWORD=password +POSTGRES_DB=postgres_db +POSTGRES_DATABASE_URL="postgres://user:password@localhost:5432/postgres_db" +DATABASE_URL="postgres://user:password@localhost:5432/postgres_db" \ No newline at end of file diff --git a/examples/socketio-server/Cargo.toml b/examples/socketio-server/Cargo.toml new file mode 100644 index 0000000..d443512 --- /dev/null +++ b/examples/socketio-server/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "socketio-server" +version = "0.1.0" +edition = "2021" + +[dependencies] +sword = { workspace = true, features = ["websocket"] } +sword-macros = { workspace = true } +sqlx = { version = "0.8.6", features = ["postgres", "runtime-tokio", "time"] } +serde = { workspace = true } +dotenv = "0.15.0" diff --git a/examples/socketio-server/compose.yaml b/examples/socketio-server/compose.yaml new file mode 100644 index 0000000..422391a --- /dev/null +++ b/examples/socketio-server/compose.yaml @@ -0,0 +1,20 @@ +version: '3.8' + +services: + postgres: + image: postgres:15 + env_file: + - "./.env" + container_name: sword_postgres_example + restart: unless-stopped + environment: + POSTGRES_USER: ${POSTGRES_USER} + POSTGRES_PASSWORD: ${POSTGRES_PASSWORD} + POSTGRES_DB: ${POSTGRES_DB} + ports: + - "5432:5432" + volumes: + - postgres_data:/var/lib/postgresql/data + +volumes: + postgres_data: \ No newline at end of file diff --git a/examples/socketio-server/config/config.toml b/examples/socketio-server/config/config.toml new file mode 100644 index 0000000..7513ee4 --- /dev/null +++ b/examples/socketio-server/config/config.toml @@ -0,0 +1,9 @@ +[application] +host = "0.0.0.0" +port = 8080 +body_limit = "10MB" +name = "Dependency Injection Example" + +[db-config] +uri = "${POSTGRES_DATABASE_URL}" +migrations_path = "config/migrations" diff --git a/examples/socketio-server/config/migrations/20251024204317_tasks.sql b/examples/socketio-server/config/migrations/20251024204317_tasks.sql new file mode 100644 index 0000000..26d7805 --- /dev/null +++ b/examples/socketio-server/config/migrations/20251024204317_tasks.sql @@ -0,0 +1,6 @@ +-- Add migration script here + +CREATE TABLE tasks( + id INT PRIMARY KEY, + title TEXT NOT NULL +) \ No newline at end of file diff --git a/examples/socketio-server/src/database.rs b/examples/socketio-server/src/database.rs new file mode 100644 index 0000000..846257f --- /dev/null +++ b/examples/socketio-server/src/database.rs @@ -0,0 +1,42 @@ +use std::{path::Path, sync::Arc}; + +use serde::Deserialize; +use sqlx::{migrate::Migrator, PgPool}; +use sword::prelude::*; + +#[derive(Clone, Deserialize)] +#[config(key = "db-config")] +pub struct DatabaseConfig { + uri: String, + migrations_path: String, +} + +#[injectable(provider)] +pub struct Database { + pool: Arc, +} + +impl Database { + pub async fn new(db_conf: DatabaseConfig) -> Self { + let pool = PgPool::connect(&db_conf.uri) + .await + .expect("Failed to create Postgres connection pool"); + + let migrator = Migrator::new(Path::new(&db_conf.migrations_path)) + .await + .unwrap(); + + migrator + .run(&pool) + .await + .expect("Failed to run database migrations"); + + Self { + pool: Arc::new(pool), + } + } + + pub fn get_pool(&self) -> &PgPool { + &self.pool + } +} diff --git a/examples/socketio-server/src/main.rs b/examples/socketio-server/src/main.rs new file mode 100644 index 0000000..fbcc68b --- /dev/null +++ b/examples/socketio-server/src/main.rs @@ -0,0 +1,134 @@ +use dotenv::dotenv; +use std::sync::Arc; +use sword::prelude::*; +use sword_macros::{ + on_connection, on_disconnect, on_fallback, subscribe_message, web_socket, + web_socket_gateway, +}; + +use crate::database::{Database, DatabaseConfig}; +mod database; + +#[controller("/ohno")] +struct AppController {} + +#[routes] +impl AppController { + #[get("/test")] + async fn get_data(&self, _req: Request) -> HttpResponse { + let data = vec![ + "This is a basic web server", + "It serves static data", + "You can extend it with more routes", + ]; + + HttpResponse::Ok().data(data) + } +} + +#[web_socket_gateway] +struct SocketController { + db: Arc, +} + +#[web_socket_gateway] +struct OtherSocketController {} + +#[web_socket("/other_socket")] +impl OtherSocketController { + #[on_connection] + async fn on_connect(&self, _socket: SocketRef) { + println!("New client connected to OtherSocketController"); + } +} + +#[web_socket("/socket")] +impl SocketController { + #[on_connection] + async fn on_connect(&self, _socket: SocketRef) { + println!("New client connected"); + } + + #[subscribe_message("message")] + async fn on_message(&self, _socket: SocketRef, Data(_data): Data) { + println!("New message received"); + + let now = sqlx::query("SELECT NOW() as now") + .fetch_one(self.db.get_pool()) + .await + .expect("Oh no"); + + println!("Database time: {:?}", now); + } + + #[subscribe_message("message2")] + async fn other_message(&self, _socket: SocketRef, Data(_data): Data) { + println!("Other message received"); + } + + #[subscribe_message("message-with-ack")] + async fn message_with_ack( + &self, + Event(_event): Event, + Data(_data): Data, + ack: AckSender, + ) { + println!("Message with ack received"); + let response = Value::from("Acknowledged!"); + ack.send(&response).ok(); + } + + #[subscribe_message("message-with-event")] + async fn message_with_event( + &self, + Event(event): Event, + Data(data): Data, + ) { + println!("Message with event '{}' and data: {:?}", event, data); + } + + #[subscribe_message("another-message")] + async fn and_another_one_message(&self, _socket: SocketRef, ack: AckSender) { + println!("Another message received"); + + ack.send("response for another-message").ok(); + } + + #[subscribe_message("just-another-message")] + async fn just_another_message(&self) { + println!("Message with just-another-message received"); + } + + #[on_disconnect] + async fn on_disconnect(&self, _socket: SocketRef) { + println!("Socket disconnected"); + } + + #[on_fallback] + async fn on_fallback(&self, Event(event): Event, Data(data): Data) { + println!( + "Fallback handler invoked for event: {} with data: {:?}", + event, data + ); + } +} + +#[sword::main] +async fn main() { + dotenv().ok(); + + let app = Application::builder(); + let db_config = app.config::().unwrap(); + let db = Database::new(db_config).await; + + let container = DependencyContainer::builder().register_provider(db).build(); + + let app = Application::builder() + .with_dependency_container(container) + .with_socket::() + .with_socket::() + .with_controller::() + .build(); + + app.run().await; +} diff --git a/sword-macros/src/lib.rs b/sword-macros/src/lib.rs index f90d464..7f0e7b7 100644 --- a/sword-macros/src/lib.rs +++ b/sword-macros/src/lib.rs @@ -28,6 +28,8 @@ mod middlewares; mod injectable; +mod websocket; + /// Defines a handler for HTTP GET requests. /// This macro should be used inside an `impl` block of a struct annotated with the `#[controller]` macro. /// @@ -643,3 +645,146 @@ pub fn main(_args: TokenStream, item: TokenStream) -> TokenStream { output.into() } + +/// Marks a struct as a WebSocket gateway controller. +/// This macro should be used in combination with the `#[web_socket]` macro for handler implementation. +/// +/// ### Usage +/// ```rust,ignore +/// #[web_socket_gateway] +/// struct ChatSocket; +/// +/// #[web_socket("/chat")] +/// impl ChatSocket { +/// #[on_connection] +/// async fn on_connect(&self, socket: SocketRef) { +/// println!("Client connected"); +/// } +/// } +/// ``` +#[proc_macro_attribute] +pub fn web_socket_gateway(attr: TokenStream, item: TokenStream) -> TokenStream { + websocket::expand_websocket_gateway(attr, item) +} + +/// Defines WebSocket handlers for a struct. +/// This macro should be used inside an `impl` block of a struct annotated with the `#[web_socket_gateway]` macro. +/// +/// ### Parameters +/// - `path`: The path for the WebSocket endpoint, e.g., `"/socket"` +/// +/// ### Usage +/// ```rust,ignore +/// #[web_socket_gateway] +/// struct SocketController; +/// +/// #[web_socket("/socket")] +/// impl SocketController { +/// #[on_connection] +/// async fn on_connect(&self, socket: SocketRef) { +/// println!("Client connected"); +/// } +/// +/// #[subscribe_message("message")] +/// async fn on_message(&self, socket: SocketRef, Data(msg): Data) { +/// println!("Received: {}", msg); +/// } +/// +/// #[on_disconnect] +/// async fn on_disconnect(&self, socket: WebSocket) { +/// println!("Client disconnected"); +/// } +/// } +/// ``` +#[proc_macro_attribute] +pub fn web_socket(attr: TokenStream, item: TokenStream) -> TokenStream { + websocket::expand_websocket(attr, item) +} + +/// Marks a method as a WebSocket connection handler. +/// This method will be called when a client establishes a WebSocket connection. +/// +/// ### Parameters +/// The handler receives a `SocketRef` parameter for interacting with the connected client. +/// +/// ### Usage +/// ```rust,ignore +/// #[on_connection] +/// async fn on_connect(&self, socket: SocketRef) { +/// println!("Client connected: {}", socket.id); +/// } +/// ``` +#[proc_macro_attribute] +pub fn on_connection(attr: TokenStream, item: TokenStream) -> TokenStream { + let _ = attr; + item +} + +/// Marks a method as a WebSocket disconnection handler. +/// This method will be called when a client disconnects from the WebSocket. +/// +/// ### Parameters +/// The handler receives a `WebSocket` parameter with the disconnected client's information. +/// +/// ### Usage +/// ```rust,ignore +/// #[on_disconnect] +/// async fn on_disconnect(&self, socket: WebSocket) { +/// println!("Client disconnected: {}", socket.id()); +/// } +/// ``` +#[proc_macro_attribute] +pub fn on_disconnect(attr: TokenStream, item: TokenStream) -> TokenStream { + let _ = attr; + item +} + +/// Marks a method as a WebSocket message handler. +/// This method will be called when the client emits an event with the specified message type. +/// +/// ### Parameters +/// - `message_type`: The name of the event to handle, e.g., `"message"` or `"*"` for any event +/// +/// ### Parameters in handler +/// - `socket: SocketRef` - The connected client's socket +/// - `Data(data): Data` - The message payload deserialized to type T +/// - `ack: AckSender` (optional) - For sending acknowledgments back to the client +/// +/// ### Usage +/// ```rust,ignore +/// #[subscribe_message("message")] +/// async fn on_message(&self, socket: SocketRef, Data(msg): Data) { +/// println!("Received: {}", msg); +/// } +/// +/// #[subscribe_message("request")] +/// async fn on_request(&self, Data(req): Data, ack: AckSender) { +/// ack.send("response").ok(); +/// } +/// ``` +#[proc_macro_attribute] +pub fn subscribe_message(attr: TokenStream, item: TokenStream) -> TokenStream { + let _ = attr; + item +} + +/// Marks a method as a WebSocket fallback handler. +/// This method will be called for any event that doesn't match a specific `#[subscribe_message]` handler. +/// It's useful for debugging or handling dynamic events. +/// +/// ### Parameters in handler +/// - `Event(name): Event` - The event name +/// - `Data(data): Data` - The message payload +/// +/// ### Usage +/// ```rust,ignore +/// #[on_fallback] +/// async fn on_fallback(&self, Event(event): Event, Data(data): Data) { +/// println!("Unhandled event: {} with data: {:?}", event, data); +/// } +/// ``` +#[proc_macro_attribute] +pub fn on_fallback(attr: TokenStream, item: TokenStream) -> TokenStream { + let _ = attr; + item +} diff --git a/sword-macros/src/websocket/expand.rs b/sword-macros/src/websocket/expand.rs new file mode 100644 index 0000000..9b38963 --- /dev/null +++ b/sword-macros/src/websocket/expand.rs @@ -0,0 +1,371 @@ +//! WebSocket macro expansion logic + +use proc_macro::TokenStream; +use quote::quote; +use syn::{ImplItem, ItemImpl, ItemStruct, parse_macro_input}; + +use super::parsing::{HandlerType, WebSocketPath, get_handler_type}; + +/// Expands the `#[web_socket_gateway]` macro +pub fn expand_websocket_gateway( + _attr: TokenStream, + item: TokenStream, +) -> TokenStream { + let input = parse_macro_input!(item as ItemStruct); + let name = &input.ident; + let vis = &input.vis; + let fields = &input.fields; + + // Extract field names and types for Build implementation + let field_inits = if let syn::Fields::Named(named_fields) = fields { + named_fields.named.iter().map(|field| { + let field_name = &field.ident; + let field_type = &field.ty; + quote! { + #field_name: <#field_type as ::sword::core::FromStateArc>::from_state_arc(state) + .map_err(|_| ::sword::core::DependencyInjectionError::DependencyNotFound { + type_name: stringify!(#field_type).to_string(), + })? + } + }).collect::>() + } else { + vec![] + }; + + let expanded = quote! { + #[derive(Clone)] + #vis struct #name #fields + + impl #name { + pub fn router(state: ::sword::core::State) -> ::sword::__internal::AxumRouter { + ::sword::__internal::AxumRouter::new().with_state(state) + } + } + + impl ::sword::web::websocket::WebSocketGateway for #name { + fn router(state: ::sword::core::State) -> ::sword::__internal::AxumRouter { + Self::router(state) + } + } + + impl ::sword::core::Build for #name { + type Error = ::sword::core::DependencyInjectionError; + + fn build(state: &::sword::core::State) -> Result { + Ok(Self { + #(#field_inits),* + }) + } + } + }; + + expanded.into() +} + +// Helper struct to store handler parameter information +struct HandlerParams { + has_socket: bool, + has_data: bool, + has_event: bool, + has_ack: bool, +} + +// Helper function to analyze handler parameters +fn analyze_handler_params(sig: &syn::Signature) -> HandlerParams { + let mut params = HandlerParams { + has_socket: false, + has_data: false, + has_event: false, + has_ack: false, + }; + + for arg in &sig.inputs { + if let syn::FnArg::Typed(pat_type) = arg { + let ty_str = quote::quote!(#pat_type.ty).to_string(); + if ty_str.contains("SocketRef") { + params.has_socket = true; + } + if ty_str.contains("Data") { + params.has_data = true; + } + if ty_str.contains("Event") { + params.has_event = true; + } + if ty_str.contains("AckSender") { + params.has_ack = true; + } + } + } + + params +} + +/// Expands the `#[web_socket]` macro +pub fn expand_websocket(attr: TokenStream, item: TokenStream) -> TokenStream { + let path_struct = parse_macro_input!(attr as WebSocketPath); + let path = path_struct.path; + let input = parse_macro_input!(item as ItemImpl); + + let self_ty = &input.self_ty; + + // Extract all handler methods from the impl block with their parameter info + let mut on_connect_handler: Option<(syn::Ident, HandlerParams)> = None; + let mut message_handlers: Vec<(String, syn::Ident, HandlerParams)> = Vec::new(); + let mut on_disconnect_handler: Option<(syn::Ident, HandlerParams)> = None; + let mut on_fallback_handler: Option<(syn::Ident, HandlerParams)> = None; + + for item in &input.items { + if let ImplItem::Fn(method) = item { + if let Some((handler_type, message_type)) = + get_handler_type(&method.attrs) + { + let method_name = method.sig.ident.clone(); + let params = analyze_handler_params(&method.sig); + + match handler_type { + HandlerType::OnConnection => { + on_connect_handler = Some((method_name, params)); + } + HandlerType::OnDisconnect => { + on_disconnect_handler = Some((method_name, params)); + } + HandlerType::OnFallback => { + on_fallback_handler = Some((method_name, params)); + } + HandlerType::SubscribeMessage => { + let msg_name = + message_type.unwrap_or_else(|| "message".to_string()); + message_handlers.push((msg_name, method_name, params)); + } + } + } + } + } + + // Generate message handler codes that call the actual methods + // Note: For now we skip message handlers due to lifetime constraints + // Message handlers would need to be registered outside the socket setup + #[allow(unused)] + let message_handler_codes: Vec = vec![]; + + // Generate message handler registration code + let message_handler_registrations = message_handlers.iter().map(|(event_name, method_name, params)| { + // Build the closure parameters based on what the handler expects + let closure_params = if params.has_socket && params.has_data && params.has_ack { + quote! { socket: sword::prelude::websocket::SocketRef, + data: sword::prelude::websocket::Data, + ack: sword::prelude::websocket::AckSender } + } else if params.has_socket && params.has_data { + quote! { socket: sword::prelude::websocket::SocketRef, + data: sword::prelude::websocket::Data } + } else if params.has_data && params.has_ack { + quote! { _socket: sword::prelude::websocket::SocketRef, + data: sword::prelude::websocket::Data, + ack: sword::prelude::websocket::AckSender } + } else if params.has_socket && params.has_ack { + quote! { socket: sword::prelude::websocket::SocketRef, + _data: sword::prelude::websocket::Data, + ack: sword::prelude::websocket::AckSender } + } else if params.has_data { + quote! { _socket: sword::prelude::websocket::SocketRef, + data: sword::prelude::websocket::Data } + } else if params.has_socket { + quote! { socket: sword::prelude::websocket::SocketRef, + _data: sword::prelude::websocket::Data } + } else if params.has_ack { + quote! { _socket: sword::prelude::websocket::SocketRef, + _data: sword::prelude::websocket::Data, + ack: sword::prelude::websocket::AckSender } + } else { + quote! { _socket: sword::prelude::websocket::SocketRef, + _data: sword::prelude::websocket::Data } + }; + + // Build the method call parameters based on what the handler expects + // Standard order: SocketRef, Event, Data, AckSender + let mut call_params = vec![]; + if params.has_socket { + call_params.push(quote! { socket }); + } + if params.has_event { + call_params.push(quote! { sword::prelude::websocket::Event(#event_name.to_string()) }); + } + if params.has_data { + call_params.push(quote! { data }); + } + if params.has_ack { + call_params.push(quote! { ack }); + } + + quote! { + { + let state_for_msg = state_for_handler.clone(); + socket.on(#event_name, move |#closure_params| { + let state = state_for_msg.clone(); + async move { + match as ::sword::core::FromStateArc>::from_state_arc(&state) { + Ok(controller) => { + controller.#method_name(#(#call_params),*).await; + } + Err(e) => { + sword::__internal::tracing::error!("Failed to instantiate controller for message handler: {}", e); + } + } + } + }); + } + } + }).collect::>(); + + // Generate disconnect handler registration code + let disconnect_handler_registration = if let Some((disconnect_method, params)) = + on_disconnect_handler.as_ref() + { + let call_params = if params.has_socket { + quote! { sock } + } else { + quote! {} + }; + + quote! { + { + let state_for_disconnect = state_for_handler.clone(); + let socket_for_disconnect = socket.clone(); + socket.on_disconnect(move || { + let state = state_for_disconnect.clone(); + let sock = socket_for_disconnect.clone(); + async move { + match as ::sword::core::FromStateArc>::from_state_arc(&state) { + Ok(controller) => { + controller.#disconnect_method(#call_params).await; + } + Err(e) => { + sword::__internal::tracing::error!("Failed to instantiate controller for disconnect handler: {}", e); + } + } + } + }); + } + } + } else { + quote! {} + }; + + // Generate fallback handler registration code (catch-all for unhandled events) + let fallback_handler_registration = if let Some((fallback_method, params)) = + on_fallback_handler.as_ref() + { + // Build the closure parameters based on what the handler expects + let closure_params = if params.has_event && params.has_data { + quote! { event: sword::prelude::websocket::Event, + data: sword::prelude::websocket::Data } + } else if params.has_event { + quote! { event: sword::prelude::websocket::Event, + _data: sword::prelude::websocket::Data } + } else if params.has_data { + quote! { _event: sword::prelude::websocket::Event, + data: sword::prelude::websocket::Data } + } else { + quote! { _event: sword::prelude::websocket::Event, + _data: sword::prelude::websocket::Data } + }; + + // Build the method call parameters + let mut call_params = vec![]; + if params.has_event { + call_params.push(quote! { event }); + } + if params.has_data { + call_params.push(quote! { data }); + } + + quote! { + { + let state_for_fallback = state_for_handler.clone(); + // Use socketioxide's on_fallback method for catch-all event handling + socket.on_fallback(move |#closure_params| { + let state = state_for_fallback.clone(); + async move { + match as ::sword::core::FromStateArc>::from_state_arc(&state) { + Ok(controller) => { + controller.#fallback_method(#(#call_params),*).await; + } + Err(e) => { + sword::__internal::tracing::error!("Failed to instantiate controller for fallback handler: {}", e); + } + } + } + }); + } + } + } else { + quote! {} + }; + + // Generate the handler call code + let handler_call = if let Some((handler_name, params)) = + on_connect_handler.as_ref() + { + let call_params = if params.has_socket { + quote! { socket.clone() } + } else { + quote! {} + }; + + quote! { + match as ::sword::core::FromStateArc>::from_state_arc(&state_for_handler) { + Ok(controller) => { + controller.#handler_name(#call_params).await; + } + Err(e) => { + sword::__internal::tracing::error!("Failed to instantiate controller: {}", e); + } + } + } + } else { + quote! { + sword::__internal::tracing::info!(ns = socket.ns(), ?socket.id, "Socket.IO connected"); + } + }; + + let expanded = quote! { + #input + + impl ::sword::web::websocket::WebSocketProvider for #self_ty { + fn path() -> &'static str { + Box::leak(#path.into()) + } + + fn get_setup_fn(_: ::sword::core::State) -> ::sword::web::websocket::SocketSetupFn { + std::sync::Arc::new(|io: &::sword::web::websocket::SocketIo, state_from_call: ::sword::core::State| { + let app_state = state_from_call.clone(); // Clone state BEFORE ns() call + + io.ns(#path, move |socket: ::sword::web::websocket::SocketRef| { + let state_for_handler = app_state.clone(); // Clone again for the async block + + // Run connection handler (async) + Box::pin(async move { + // Call on_connect handler with captured state + #handler_call + + // Register message handlers + #(#message_handler_registrations)* + + // Register fallback handler (must be before disconnect to catch all unhandled events) + #fallback_handler_registration + + // Register disconnect handler + #disconnect_handler_registration + }) + }); + }) + } + + fn router(state: ::sword::core::State) -> ::sword::__internal::AxumRouter { + ::sword::__internal::AxumRouter::new().with_state(state) + } + } + }; + + expanded.into() +} diff --git a/sword-macros/src/websocket/mod.rs b/sword-macros/src/websocket/mod.rs new file mode 100644 index 0000000..4951d00 --- /dev/null +++ b/sword-macros/src/websocket/mod.rs @@ -0,0 +1,4 @@ +pub mod expand; +pub mod parsing; + +pub use expand::*; diff --git a/sword-macros/src/websocket/parsing.rs b/sword-macros/src/websocket/parsing.rs new file mode 100644 index 0000000..bd20081 --- /dev/null +++ b/sword-macros/src/websocket/parsing.rs @@ -0,0 +1,52 @@ +//! Parsing logic for WebSocket macros + +use syn::{ + Attribute, LitStr, Result as SynResult, + parse::{Parse, ParseStream}, +}; + +pub struct WebSocketPath { + pub path: String, +} + +impl Parse for WebSocketPath { + fn parse(input: ParseStream) -> SynResult { + let lit_str: LitStr = input.parse()?; + Ok(WebSocketPath { + path: lit_str.value(), + }) + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum HandlerType { + OnConnection, + OnDisconnect, + SubscribeMessage, + OnFallback, +} + +pub fn get_handler_type( + attrs: &[Attribute], +) -> Option<(HandlerType, Option)> { + for attr in attrs { + if attr.path().is_ident("on_connection") { + return Some((HandlerType::OnConnection, None)); + } + if attr.path().is_ident("on_disconnect") { + return Some((HandlerType::OnDisconnect, None)); + } + if attr.path().is_ident("on_fallback") { + return Some((HandlerType::OnFallback, None)); + } + if attr.path().is_ident("subscribe_message") { + if let Ok(message_type) = attr.parse_args::() { + return Some(( + HandlerType::SubscribeMessage, + Some(message_type.value()), + )); + } + } + } + None +} diff --git a/sword/Cargo.toml b/sword/Cargo.toml index 5723e4c..7ac5db7 100644 --- a/sword/Cargo.toml +++ b/sword/Cargo.toml @@ -36,6 +36,10 @@ byte-unit = "5.1.6" regex-lite = "0.1.7" validator = { workspace = true, optional = true } +socketioxide = { version = "0.17.2", optional = true } +rmpv = { version = "1.3.0", features = ["serde", "with-serde"], optional = true } +tracing = { workspace = true } + tower = "0.5.2" tower-http = { version = "0.6.6", features = ["limit", "timeout"] } tower-cookies = { version = "0.11.0", optional = true } @@ -57,4 +61,5 @@ multipart = ["axum/multipart", "dep:bytes"] cookies = ["dep:tower-cookies", "tower-cookies/signed", "tower-cookies/private"] helmet = ["dep:axum-helmet"] validator = ["dep:validator"] +websocket = ["dep:socketioxide","dep:rmpv"] hot-reload = ["dep:subsecond", "dep:dioxus-devtools", "sword-macros/hot-reload"] diff --git a/sword/src/core/application/app.rs b/sword/src/core/application/app.rs index de1906b..20d5df1 100644 --- a/sword/src/core/application/app.rs +++ b/sword/src/core/application/app.rs @@ -1,5 +1,7 @@ use axum::routing::Router; use axum_responses::http::HttpResponse; +#[cfg(feature = "websocket")] +use socketioxide::SocketIo; use tokio::net::TcpListener as Listener; use crate::core::{ @@ -13,13 +15,33 @@ use crate::core::{ /// the web server, routing, and application configuration. It provides a /// builder pattern for configuration and methods to run the application. pub struct Application { - router: Router, - config: Config, + pub(crate) router: Router, + pub(crate) config: Config, + pub(crate) state: crate::core::State, + #[cfg(feature = "websocket")] + pub(crate) socket_setups: + Vec<(&'static str, crate::web::websocket::SocketSetupFn)>, } impl Application { - pub fn new(router: Router, config: Config) -> Self { - Self { router, config } + pub fn new(router: Router, config: Config, state: crate::core::State) -> Self { + Self { + router, + config, + state, + #[cfg(feature = "websocket")] + socket_setups: Vec::new(), + } + } + + #[cfg(feature = "websocket")] + pub(crate) fn with_socket_setup( + mut self, + path: &'static str, + setup_fn: crate::web::websocket::SocketSetupFn, + ) -> Self { + self.socket_setups.push((path, setup_fn)); + self } /// Creates a new application builder for configuring the application. @@ -94,10 +116,32 @@ impl Application { HttpResponse::NotFound().message("The requested resource was not found") }); - axum::serve(listener, router) - .await - .map_err(|e| ApplicationError::ServerError { source: e }) - .expect("Internal server error"); + #[cfg(feature = "websocket")] + { + let (layer, io) = SocketIo::new_layer(); + + // Register all socket setups + for (_path, setup_fn) in &self.socket_setups { + // Call the setup function which handles namespace registration 💀 + // This allows handlers to be registered properly for each namespace + setup_fn(&io, self.state.clone()); + } + + let router = router.layer(layer); + + axum::serve(listener, router) + .await + .map_err(|e| ApplicationError::ServerError { source: e }) + .expect("Internal server error"); + } + + #[cfg(not(feature = "websocket"))] + { + axum::serve(listener, router) + .await + .map_err(|e| ApplicationError::ServerError { source: e }) + .expect("Internal server error"); + } } /// Runs the application server with graceful shutdown support. diff --git a/sword/src/core/application/builder.rs b/sword/src/core/application/builder.rs index f698b88..ee67109 100644 --- a/sword/src/core/application/builder.rs +++ b/sword/src/core/application/builder.rs @@ -37,7 +37,7 @@ use crate::{ /// .with_layer(tower_http::cors::CorsLayer::permissive()) /// .build(); /// ``` -#[derive(Debug, Clone)] +#[derive(Clone)] pub struct ApplicationBuilder { /// The internal Axum router that handles HTTP requests. router: Router, @@ -53,6 +53,10 @@ pub struct ApplicationBuilder { /// Flag to track if middlewares have been registered middlewares_registered: bool, + + /// Socket.IO setup configurations + #[cfg(feature = "websocket")] + socket_setups: Vec<(&'static str, crate::web::websocket::SocketSetupFn)>, } impl ApplicationBuilder { @@ -94,6 +98,8 @@ impl ApplicationBuilder { config, prefix: None, middlewares_registered: false, + #[cfg(feature = "websocket")] + socket_setups: Vec::new(), } } @@ -147,6 +153,65 @@ impl ApplicationBuilder { config: self.config, prefix: self.prefix, middlewares_registered: self.middlewares_registered, + #[cfg(feature = "websocket")] + socket_setups: self.socket_setups, + } + } + + /// Registers a WebSocket gateway in the application. + /// + /// This could be used to add WebSocket support. + /// Currently is used just with Socket.IO via `socketioxide`. + /// + /// ### Type Parameters + /// + /// * `W` - A type implementing `WebSocketGateway` that defines the WebSocket handlers + /// + /// ### Example + /// + /// ```rust,ignore + /// use sword::prelude::*; + /// + /// #[web_socket_gateway] + /// struct SocketController; + /// + /// #[web_socket("/socket")] + /// impl SocketController { + /// #[on_connection] + /// async fn on_connect(&self, socket: SocketRef) { + /// println!("Connected"); + /// } + /// } + /// + /// let app = Application::builder() + /// .with_socket::() + /// .build(); + /// ``` + #[cfg(feature = "websocket")] + pub fn with_socket(self) -> Self + where + W: crate::web::websocket::WebSocketProvider + Clone + Send + Sync + 'static, + W: crate::core::Build, + { + let path = W::path(); + + if let Ok(controller) = W::build(&self.state) { + let _ = self.state.insert(controller); + } + + let setup_fn = W::get_setup_fn(self.state.clone()); + + let mut socket_setups = self.socket_setups; + socket_setups.push((path, setup_fn)); + + Self { + router: self.router, + state: self.state, + config: self.config, + prefix: self.prefix, + middlewares_registered: self.middlewares_registered, + #[cfg(feature = "websocket")] + socket_setups, } } @@ -188,6 +253,8 @@ impl ApplicationBuilder { config: self.config, prefix: self.prefix, middlewares_registered: self.middlewares_registered, + #[cfg(feature = "websocket")] + socket_setups: self.socket_setups, } } @@ -210,6 +277,8 @@ impl ApplicationBuilder { config: self.config, prefix: self.prefix, middlewares_registered: self.middlewares_registered, + #[cfg(feature = "websocket")] + socket_setups: self.socket_setups, } } @@ -224,6 +293,8 @@ impl ApplicationBuilder { config: self.config, prefix: Some(prefix.into()), middlewares_registered: self.middlewares_registered, + #[cfg(feature = "websocket")] + socket_setups: self.socket_setups, } } @@ -271,7 +342,16 @@ impl ApplicationBuilder { router = Router::new().nest(prefix, router); } - Application::new(router, self.config) + let mut app = Application::new(router, self.config, self.state.clone()); + + #[cfg(feature = "websocket")] + { + for (path, setup_fn) in self.socket_setups { + app = app.with_socket_setup(path, setup_fn); + } + } + + app } } diff --git a/sword/src/lib.rs b/sword/src/lib.rs index dcdff81..d3d4f79 100644 --- a/sword/src/lib.rs +++ b/sword/src/lib.rs @@ -25,7 +25,10 @@ pub mod prelude; pub mod web; -pub use sword_macros::main; +pub use sword_macros::{ + main, on_connection, on_disconnect, subscribe_message, web_socket, + web_socket_gateway, +}; #[doc(hidden)] pub mod __internal { @@ -46,6 +49,8 @@ pub mod __internal { pub use tokio::runtime as tokio_runtime; + pub use tracing; + #[cfg(feature = "hot-reload")] pub use dioxus_devtools; #[cfg(feature = "hot-reload")] diff --git a/sword/src/prelude.rs b/sword/src/prelude.rs index 51b8ba3..bda6a5b 100644 --- a/sword/src/prelude.rs +++ b/sword/src/prelude.rs @@ -73,5 +73,8 @@ pub use crate::web::cookies::*; #[cfg(feature = "multipart")] pub use crate::web::multipart; +#[cfg(feature = "websocket")] +pub use crate::web::websocket::*; + #[cfg(feature = "validator")] pub use crate::web::request_validator::*; diff --git a/sword/src/web/mod.rs b/sword/src/web/mod.rs index d7152cb..b4a13b2 100644 --- a/sword/src/web/mod.rs +++ b/sword/src/web/mod.rs @@ -3,12 +3,16 @@ mod middleware; mod request; mod response; +#[cfg(feature = "websocket")] +pub mod websocket; + pub use axum::http::{Method, StatusCode, header}; pub use controller::*; pub use middleware::*; pub use request::{Request, RequestError}; pub use response::*; +pub use websocket::*; #[cfg(feature = "multipart")] pub use request::multipart; diff --git a/sword/src/web/websocket/handler.rs b/sword/src/web/websocket/handler.rs new file mode 100644 index 0000000..3d62b69 --- /dev/null +++ b/sword/src/web/websocket/handler.rs @@ -0,0 +1,39 @@ +//! WebSocket handler traits and utilities + +use crate::core::State; +use axum::routing::Router; +#[cfg(feature = "websocket")] +use socketioxide::SocketIo; +use std::sync::Arc; + +/// Trait for WebSocket gateway controllers +/// +/// This trait is automatically implemented by types annotated with `#[web_socket_gateway]` +/// and provides the necessary methods for registering WebSocket handlers with the application. +#[cfg(feature = "websocket")] +pub trait WebSocketGateway: Send + Sync + 'static { + /// Creates a router for the WebSocket handlers + fn router(state: State) -> Router; +} + +/// Trait for providing WebSocket routes +/// +/// This trait is implemented by types annotated with the `#[web_socket]` macro +/// and contains the actual handler implementations for WebSocket events. +#[cfg(feature = "websocket")] +pub trait WebSocketProvider: Send + Sync + 'static { + /// Returns the path where the WebSocket is mounted + fn path() -> &'static str; + + /// Returns a setup function that configures handlers on the SocketIo instance + /// This should register the namespace and set up all handlers + fn get_setup_fn(state: State) -> SocketSetupFn; + + /// Creates a router for the WebSocket + fn router(state: State) -> Router; +} + +/// Type for Socket.IO setup functions +/// Takes both SocketIo reference and the application state for dependency injection +#[cfg(feature = "websocket")] +pub type SocketSetupFn = Arc; diff --git a/sword/src/web/websocket/mod.rs b/sword/src/web/websocket/mod.rs new file mode 100644 index 0000000..96afa95 --- /dev/null +++ b/sword/src/web/websocket/mod.rs @@ -0,0 +1,37 @@ +//! SocketIO support for Sword. +//! +//! This module provides types and traits for handling SocketIO connections +//! via `socketioxide`. +//! +//! # Example +//! +//! ```rust,ignore +//! use sword::prelude::*; +//! +//! #[web_socket_gateway] +//! struct ChatSocket; +//! +//! #[web_socket("/chat")] +//! impl ChatSocket { +//! #[on_connection] +//! async fn on_connect(&self, socket: SocketRef) { +//! println!("Client connected: {}", socket.id); +//! } +//! +//! #[subscribe_message("message")] +//! async fn on_message(&self, socket: SocketRef, Data(msg): Data) { +//! socket.inner().emit("message", msg).ok(); +//! } +//! +//! #[on_disconnect] +//! async fn on_disconnect(&self, socket: WebSocket) { +//! println!("Client disconnected: {}", socket.id()); +//! } +//! } +//! ``` + +pub mod handler; +pub mod types; + +pub use handler::*; +pub use types::*; diff --git a/sword/src/web/websocket/types.rs b/sword/src/web/websocket/types.rs new file mode 100644 index 0000000..386ad0a --- /dev/null +++ b/sword/src/web/websocket/types.rs @@ -0,0 +1,8 @@ +use socketioxide; + +// Re-export socketioxide types directly +pub use rmpv::{Value, ValueRef}; +pub use socketioxide::SocketIo; +pub use socketioxide::extract::{AckSender, Data, Event}; + +pub use socketioxide::extract::SocketRef;