diff --git a/.cargo/config.toml b/.cargo/config.toml index bff29e6..9c68c3a 100644 --- a/.cargo/config.toml +++ b/.cargo/config.toml @@ -1,2 +1,5 @@ [build] rustflags = ["--cfg", "tokio_unstable"] + +[registries.crates-io] +protocol = "sparse" diff --git a/Cargo.lock b/Cargo.lock index 0d5e312..33a8c97 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -851,7 +851,6 @@ dependencies = [ "protocol", "rand", "reqwest", - "serde_json", "sha1", "tokio", "tracing", @@ -1390,7 +1389,7 @@ dependencies = [ "if-addrs", "log", "multimap", - "nix", + "nix 0.23.2", "rand", "socket2", "thiserror", @@ -1733,6 +1732,18 @@ dependencies = [ "memoffset", ] +[[package]] +name = "nix" +version = "0.26.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfdda3d196821d6af13126e40375cdf7da646a96114af134d5f417a9a1dc8e1a" +dependencies = [ + "bitflags", + "cfg-if", + "libc", + "static_assertions", +] + [[package]] name = "nom" version = "7.1.3" @@ -2170,6 +2181,7 @@ dependencies = [ "axum", "clap", "common", + "nix 0.26.2", "poise", "protocol", "serde_json", @@ -2656,6 +2668,12 @@ dependencies = [ "lock_api", ] +[[package]] +name = "static_assertions" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" + [[package]] name = "streamcatcher" version = "1.0.1" diff --git a/common/src/util.rs b/common/src/util.rs index 4c54c64..91d1e10 100644 --- a/common/src/util.rs +++ b/common/src/util.rs @@ -81,14 +81,10 @@ pub async fn ctrl_c() { } #[cfg(unix)] -pub async fn ctrl_c_and_pipe() { +pub async fn usr1() { use tokio::signal::unix::{signal, SignalKind}; - let others = ctrl_c(); - let mut pipe = signal(SignalKind::pipe()).unwrap(); - tokio::select! { - _ = others => {} - _ = pipe.recv() => {} - }; + let mut usr1 = signal(SignalKind::user_defined1()).unwrap(); + let _ = usr1.recv().await; } #[cfg(test)] diff --git a/forwarder/Cargo.toml b/forwarder/Cargo.toml index 3b5866d..c7cc810 100644 --- a/forwarder/Cargo.toml +++ b/forwarder/Cargo.toml @@ -6,6 +6,9 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +protocol = { path = "../protocol" } +common = { path = "../common" } + anyhow = { version = "1.0.70", features = ["backtrace"] } clap = { version = "4.2.2", features = ["derive"] } futures-util = "0.3.28" @@ -15,11 +18,7 @@ reqwest = { version = "0.11.16", default_features = false, features = [ "rustls-tls", "json", ] } -serde_json = "1.0.96" sha1 = "0.10.5" tokio = { version = "1.27.0", features = ["full"] } tracing = "0.1.37" rand = "0.8.5" - -protocol = { path = "../protocol" } -common = { path = "../common" } diff --git a/player/Cargo.toml b/player/Cargo.toml index f589f3c..98271fb 100644 --- a/player/Cargo.toml +++ b/player/Cargo.toml @@ -5,6 +5,8 @@ edition = "2021" [dependencies] +common = { path = "../common" } + anyhow = { version = "1.0.70", features = ["backtrace"] } clap = { version = "4.2.2", features = ["derive", "env"] } hex = "0.4.3" @@ -13,4 +15,3 @@ sha1 = "0.10.5" tokio = { version = "1.27.0", features = ["full"] } tracing = "0.1.37" serde_json = "1.0.96" -common = { path = "../common" } diff --git a/player/src/main.rs b/player/src/main.rs index 1ce65b8..12b9276 100644 --- a/player/src/main.rs +++ b/player/src/main.rs @@ -84,11 +84,11 @@ pub async fn main() -> Result<()> { _ = &mut spirc_task => { tracing::debug!("spirc task finished"); } - _ = util::ctrl_c_and_pipe() => { - // what happens is songbird sends SIGKILL(9) to the last child -- gstreamer. presumably then its stdin is closed, which means our stdout is closed -> SIGPIPE - // actually what happens is the Player fails to write to stoud and then calls std::process::exit(1) :( - // TODO: what can we do about that - tracing::debug!("received ctrl-c or pipe"); + _ = util::ctrl_c() => { + tracing::debug!("received ctrl-c"); + }, + _ = util::usr1() => { + tracing::debug!("got SIGUSR1, shutting down") }, }; diff --git a/receiver/Cargo.toml b/receiver/Cargo.toml index 342439f..7099e06 100644 --- a/receiver/Cargo.toml +++ b/receiver/Cargo.toml @@ -4,6 +4,9 @@ version = "0.1.0" edition = "2021" [dependencies] +protocol = { path = "../protocol" } +common = { path = "../common" } + songbird = { version = "0.3" } anyhow = { version = "1.0.70", features = ["backtrace"] } clap = { version = "4.2.2", features = ["derive", "env"] } @@ -11,6 +14,5 @@ tokio = { version = "1.27.0", features = ["full"] } tracing = "0.1.37" serde_json = "1.0.96" axum = "0.6.15" -protocol = { path = "../protocol" } -common = { path = "../common" } poise = "0.5.3" +nix = { version = "0.26.2", default-features = false, features = ["signal"] } diff --git a/receiver/src/bot.rs b/receiver/src/bot.rs index 2a84901..cfca0ff 100644 --- a/receiver/src/bot.rs +++ b/receiver/src/bot.rs @@ -1,9 +1,9 @@ use std::{ process::{Command, Stdio}, - sync::{Arc, RwLock}, + sync::{Arc, Mutex, RwLock}, }; -use anyhow::Result; +use anyhow::{anyhow, Context as _, Result}; use clap::Parser; use poise::serenity_prelude::GatewayIntents; @@ -20,10 +20,11 @@ pub struct BotOptions { discord_token: String, } -// User data, which is stored and accessible in all command invocations +#[derive(Debug)] struct Data { bot_options: BotOptions, creds_registry: Arc>, + currently_playing_pid: Arc>>, } type Error = anyhow::Error; type Context<'a> = poise::Context<'a, Data, Error>; @@ -63,6 +64,7 @@ pub async fn run_bot(opts: BotOptions, stream_registry: Arc, #[description = "Stream key"] key: Strin let input = Input::new(true, reader, Codec::Pcm, Container::Raw, None); + // TODO: send player a signal on stop, so it can shut down gracefully before it's Dropped + let mut call_handler = call_handler_lock.lock().await; call_handler.play_source(input); @@ -210,6 +214,19 @@ async fn stop(ctx: Context<'_>) -> Result<()> { } Some(g) => g, }; + + { + let pid = { + let mut pid_mu = ctx.data().currently_playing_pid.lock().unwrap(); + pid_mu.take() + }; + if let Some(pid) = pid { + if let Err(e) = kill_player(pid as _).await.context("killing player") { + tracing::error!(?e, "failed to kill player"); + } + } + } + let voice_manager = songbird::get(ctx.serenity_context()).await.unwrap().clone(); let call_handler_lock = voice_manager.get(guild.id); if let Some(call_handler_lock) = call_handler_lock { @@ -226,3 +243,93 @@ async fn stop(ctx: Context<'_>) -> Result<()> { async fn restart(_ctx: Context<'_>) -> Result<()> { std::process::exit(0); } + +#[derive(Debug, Clone, Copy)] +enum HowKilled { + Usr1, + Term, + Kill, +} +/// gracefully kill player by sending it SIGUSR1, waiting, then sending it SIGTERM +async fn kill_player(pid: u32) -> Result { + use nix::{sys::signal::Signal, unistd::Pid}; + + let pid = Pid::from_raw(pid as i32); + tracing::debug!(?pid, "asking player to stop"); + nix::sys::signal::kill(pid, Signal::SIGUSR1).context("sending usr1")?; + + // wait for it to exit, or timeout + tokio::select! { + _ = tokio::time::sleep(std::time::Duration::from_secs(1)) => {}, + _ = async { tokio::task::spawn_blocking(move || nix::sys::wait::waitpid(pid, None).map_err(|e| anyhow!("error waiting: {:?}", e))).await? } => { + return Ok(HowKilled::Usr1); + }, + } + + tracing::warn!("player did not exit in time after USR1; sending TERM"); + nix::sys::signal::kill(pid, Signal::SIGTERM)?; + + tokio::select! { + _ = tokio::time::sleep(std::time::Duration::from_secs(1)) => {}, + _ = async { tokio::task::spawn_blocking(move || nix::sys::wait::waitpid(pid, None).map_err(|e| anyhow!("error waiting: {:?}", e))).await? } => { + return Ok(HowKilled::Term); + }, + } + + tracing::warn!("player did not exit in time after TERM; sending KILL"); + nix::sys::signal::kill(pid, Signal::SIGKILL)?; + + Ok::<_, anyhow::Error>(HowKilled::Kill) +} + +#[cfg(test)] +mod tests { + use super::*; + use tokio::io::AsyncBufReadExt; + + #[tokio::test] + async fn kill_player_works() -> Result<()> { + common::util::setup_logging()?; + + let sh = r#" +trap 'echo usr1' SIGUSR1; +trap 'echo term' SIGTERM; +echo setup +while true; do sleep 0.5; done; +"#; + let mut child = tokio::process::Command::new("bash") + .args(&["-c", sh]) + .stdout(Stdio::piped()) + .spawn()?; + let pid = child.id().unwrap(); + let out = child.stdout.take().unwrap(); + + let (setup_tx, setup_rx) = tokio::sync::oneshot::channel(); + + let output = tokio::task::spawn(async move { + let mut setup_tx = Some(setup_tx); + let mut out_str = String::new(); + let mut reader = tokio::io::BufReader::new(out).lines(); + while let Some(line) = reader.next_line().await.unwrap() { + if line.trim() == "setup" { + setup_tx.take().unwrap().send(()).unwrap(); + } + tracing::debug!("{}", line); + out_str.push_str(&format!("{}\n", line)); + } + out_str + }); + + // wait for child's handlers to be setup + setup_rx.await.unwrap(); + + tracing::debug!(?pid, "started player"); + tokio::time::sleep(std::time::Duration::from_secs(1)).await; + kill_player(pid).await?; + + let out_str = output.await?; + assert_eq!(out_str, "setup\nusr1\nterm\n"); + + Ok(()) + } +}