Skip to content

feat: optionally limit access to nodes listed in authorized_keys file #68

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,24 @@ You can now browse the website on port 3001.

# Advanced features

## Limiting access

You can limit access to a dumbpipe listener through a keys file, similar to the `authorized_keys` file that SSH uses.
You can put the file wherever you want, e.g. at `~/.dumbpipe/authorized_keys`. For the file to be used, and thus
access to be limited, specify the file path with the `--authorized-keys` (or `-a`) when launching dumbpipe.
When authorization is set, only connections from nodes listed in the file will be accepted.

Here's an example file:
```
# dumbpipe authorized nodes
148449487b53bb90382927634114457ef90d2a63127200fd8816a8dffb9d48c6 some-server
3827f5124d03d10f2f344d319a88c64c198c4db1335560ea6aad41ce2fb7c311 devbox
```

The file must contain a list of hex-encoded node ids, seperated by newlines.
The node ids may be followed by a comment, separated by a space from the encoded node id.
Lines starting with `#` are ignored and can be used as comments.

## Custom ALPNs

Dumbpipe has an expert feature to specify a custom [ALPN](https://en.wikipedia.org/wiki/Application-Layer_Protocol_Negotiation) string. You can use it to interact with
Expand Down
145 changes: 117 additions & 28 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,24 @@
//! Command line arguments.
use clap::{Parser, Subcommand};
use dumbpipe::NodeTicket;
use iroh::{endpoint::Connecting, Endpoint, NodeAddr, SecretKey, Watcher};
use n0_snafu::{Result, ResultExt};
use iroh::{
endpoint::{Connecting, Connection},
Endpoint, NodeAddr, NodeId, SecretKey, Watcher,
};
use n0_snafu::{format_err, Result, ResultExt};
use std::{
io,
net::{SocketAddr, SocketAddrV4, SocketAddrV6, ToSocketAddrs},
path::{Path, PathBuf},
str::FromStr,
sync::Arc,
};
use tokio::{
io::{AsyncRead, AsyncWrite, AsyncWriteExt},
select,
};
use tokio_util::sync::CancellationToken;
use tracing::info;

/// Create a dumb pipe between two machines, using an iroh magicsocket.
///
Expand Down Expand Up @@ -122,19 +128,47 @@ fn parse_alpn(alpn: &str) -> Result<Vec<u8>> {
})
}

/// Arguments shared among commands accepting connections.
#[derive(Parser, Debug)]
pub struct CommonAcceptArgs {
/// Optionally limit access to node ids listed in this file.
///
/// When set, only node ids listed in the file will be allowed to connect.
/// Other connections will be rejected.
///
/// The file must contain one hex-encoded node id per line. The node id may be followed
/// by a comment, separated with a space. Lines starting with `#` are ignored and may
/// be used as comments.
#[clap(short = 'a', long, value_name = "FILE")]
pub authorized_keys: Option<PathBuf>,
}

impl CommonAcceptArgs {
async fn authorized_keys(&self) -> Result<Option<AuthorizedKeys>> {
if let Some(ref path) = self.authorized_keys {
Ok(Some(AuthorizedKeys::load(path).await?))
} else {
Ok(None)
}
}
}

#[derive(Parser, Debug)]
pub struct ListenArgs {
#[clap(flatten)]
pub common: CommonArgs,
#[clap(flatten)]
pub accept: CommonAcceptArgs,
}

#[derive(Parser, Debug)]
pub struct ListenTcpArgs {
#[clap(long)]
pub host: String,

#[clap(flatten)]
pub common: CommonArgs,
#[clap(flatten)]
pub accept: CommonAcceptArgs,
}

#[derive(Parser, Debug)]
Expand Down Expand Up @@ -267,6 +301,7 @@ async fn forward_bidi(

async fn listen_stdio(args: ListenArgs) -> Result<()> {
let secret_key = get_or_create_secret()?;
let authorized_keys = args.accept.authorized_keys().await?;
let mut builder = Endpoint::builder()
.alpns(vec![args.common.alpn()?])
.secret_key(secret_key);
Expand All @@ -277,6 +312,7 @@ async fn listen_stdio(args: ListenArgs) -> Result<()> {
builder = builder.bind_addr_v6(addr);
}
let endpoint = builder.bind().await?;
eprintln!("endpoint bound with node id {}", endpoint.node_id());
// wait for the endpoint to figure out its address before making a ticket
endpoint.home_relay().initialized().await?;
let node = endpoint.node_addr().initialized().await?;
Expand Down Expand Up @@ -306,7 +342,12 @@ async fn listen_stdio(args: ListenArgs) -> Result<()> {
}
};
let remote_node_id = &connection.remote_node_id()?;
tracing::info!("got connection from {}", remote_node_id);
info!("got connection from {}", remote_node_id);
if let Some(ref authorized_keys) = authorized_keys {
if authorized_keys.authorize(&connection).is_err() {
continue;
}
}
let (s, mut r) = match connection.accept_bi().await {
Ok(x) => x,
Err(cause) => {
Expand All @@ -315,14 +356,14 @@ async fn listen_stdio(args: ListenArgs) -> Result<()> {
continue;
}
};
tracing::info!("accepted bidi stream from {}", remote_node_id);
info!("accepted bidi stream from {}", remote_node_id);
if !args.common.is_custom_alpn() {
// read the handshake and verify it
let mut buf = [0u8; dumbpipe::HANDSHAKE.len()];
r.read_exact(&mut buf).await.e()?;
snafu::ensure_whatever!(buf == dumbpipe::HANDSHAKE, "invalid handshake");
}
tracing::info!("forwarding stdin/stdout to {}", remote_node_id);
info!("forwarding stdin/stdout to {}", remote_node_id);
forward_bidi(tokio::io::stdin(), tokio::io::stdout(), r, s).await?;
// stop accepting connections after the first successful one
break;
Expand All @@ -341,23 +382,24 @@ async fn connect_stdio(args: ConnectArgs) -> Result<()> {
builder = builder.bind_addr_v6(addr);
}
let endpoint = builder.bind().await?;
eprintln!("endpoint bound with node id {}", endpoint.node_id());
let addr = args.ticket.node_addr();
let remote_node_id = addr.node_id;
// connect to the node, try only once
let connection = endpoint.connect(addr.clone(), &args.common.alpn()?).await?;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might be reasonable to check for the 403 error code here to provide a nice error message when you're not authorized. Perhaps saying something like "The remote node denied access: Your node ID is not configured as authorized in the remote authorized_keys file."

tracing::info!("connected to {}", remote_node_id);
info!("connected to {}", remote_node_id);
// open a bidi stream, try only once
let (mut s, r) = connection.open_bi().await.e()?;
tracing::info!("opened bidi stream to {}", remote_node_id);
let (mut send, recv) = connection.open_bi().await.e()?;
info!("opened bidi stream to {}", remote_node_id);
// send the handshake unless we are using a custom alpn
// when using a custom alpn, evertyhing is up to the user
if !args.common.is_custom_alpn() {
// the connecting side must write first. we don't know if there will be something
// on stdin, so just write a handshake.
s.write_all(&dumbpipe::HANDSHAKE).await.e()?;
send.write_all(&dumbpipe::HANDSHAKE).await.e()?;
}
tracing::info!("forwarding stdin/stdout to {}", remote_node_id);
forward_bidi(tokio::io::stdin(), tokio::io::stdout(), r, s).await?;
info!("forwarding stdin/stdout to {}", remote_node_id);
forward_bidi(tokio::io::stdin(), tokio::io::stdout(), recv, send).await?;
tokio::io::stdout().flush().await.e()?;
Ok(())
}
Expand All @@ -377,14 +419,12 @@ async fn connect_tcp(args: ConnectTcpArgs) -> Result<()> {
builder = builder.bind_addr_v6(addr);
}
let endpoint = builder.bind().await.context("unable to bind magicsock")?;
tracing::info!("tcp listening on {:?}", addrs);
let tcp_listener = match tokio::net::TcpListener::bind(addrs.as_slice()).await {
Ok(tcp_listener) => tcp_listener,
Err(cause) => {
tracing::error!("error binding tcp socket to {:?}: {}", addrs, cause);
return Ok(());
}
};
eprintln!("endpoint bound with node id {}", endpoint.node_id());
let tcp_listener = tokio::net::TcpListener::bind(addrs.as_slice())
.await
.with_context(|| format!("error binding tcp socket to {:?}", addrs.as_slice()))?;
info!("tcp listening on {:?}", addrs.as_slice());

async fn handle_tcp_accept(
next: io::Result<(tokio::net::TcpStream, SocketAddr)>,
addr: NodeAddr,
Expand All @@ -394,7 +434,7 @@ async fn connect_tcp(args: ConnectTcpArgs) -> Result<()> {
) -> Result<()> {
let (tcp_stream, tcp_addr) = next.context("error accepting tcp connection")?;
let (tcp_recv, tcp_send) = tcp_stream.into_split();
tracing::info!("got tcp connection from {}", tcp_addr);
info!("got tcp connection from {}", tcp_addr);
let remote_node_id = addr.node_id;
let connection = endpoint
.connect(addr, alpn)
Expand All @@ -412,8 +452,9 @@ async fn connect_tcp(args: ConnectTcpArgs) -> Result<()> {
magic_send.write_all(&dumbpipe::HANDSHAKE).await.e()?;
}
forward_bidi(tcp_recv, tcp_send, magic_recv, magic_send).await?;
Ok::<_, n0_snafu::Error>(())
Ok(())
}

let addr = args.ticket.node_addr();
loop {
// also wait for ctrl-c here so we can use it before accepting a connection
Expand All @@ -433,7 +474,7 @@ async fn connect_tcp(args: ConnectTcpArgs) -> Result<()> {
// log error at warn level
//
// we should know about it, but it's not fatal
tracing::warn!("error handling connection: {}", cause);
tracing::warn!("error handling connection: {:#}", cause);
}
});
}
Expand All @@ -447,6 +488,7 @@ async fn listen_tcp(args: ListenTcpArgs) -> Result<()> {
Err(e) => snafu::whatever!("invalid host string {}: {}", args.host, e),
};
let secret_key = get_or_create_secret()?;
let authorized_keys = args.accept.authorized_keys().await?;
let mut builder = Endpoint::builder()
.alpns(vec![args.common.alpn()?])
.secret_key(secret_key);
Expand All @@ -457,13 +499,15 @@ async fn listen_tcp(args: ListenTcpArgs) -> Result<()> {
builder = builder.bind_addr_v6(addr);
}
let endpoint = builder.bind().await?;
eprintln!("endpoint bound with node id {}", endpoint.node_id());
// wait for the endpoint to figure out its address before making a ticket
endpoint.home_relay().initialized().await?;
let node_addr = endpoint.node_addr().initialized().await?;
let mut short = node_addr.clone();
let ticket = NodeTicket::new(node_addr);
short.direct_addresses.clear();
let short = NodeTicket::new(short);
println!("ticket {short:?}");

// print the ticket on stderr so it doesn't interfere with the data itself
//
Expand All @@ -474,23 +518,27 @@ async fn listen_tcp(args: ListenTcpArgs) -> Result<()> {
if args.common.verbose > 0 {
eprintln!("or:\ndumbpipe connect-tcp {short}");
}
tracing::info!("node id is {}", ticket.node_addr().node_id);
tracing::info!("derp url is {:?}", ticket.node_addr().relay_url);
info!("node id is {}", ticket.node_addr().node_id);
info!("relay url is {:?}", ticket.node_addr().relay_url);

// handle a new incoming connection on the magic endpoint
async fn handle_magic_accept(
connecting: Connecting,
addrs: Vec<std::net::SocketAddr>,
handshake: bool,
authorized_keys: Option<AuthorizedKeys>,
) -> Result<()> {
let connection = connecting.await.context("error accepting connection")?;
let remote_node_id = &connection.remote_node_id()?;
tracing::info!("got connection from {}", remote_node_id);
info!("got connection from {}", remote_node_id);
if let Some(ref authorized_keys) = authorized_keys {
authorized_keys.authorize(&connection)?;
}
let (s, mut r) = connection
.accept_bi()
.await
.context("error accepting stream")?;
tracing::info!("accepted bidi stream from {}", remote_node_id);
info!("accepted bidi stream from {}", remote_node_id);
if handshake {
// read the handshake and verify it
let mut buf = [0u8; dumbpipe::HANDSHAKE.len()];
Expand Down Expand Up @@ -521,8 +569,11 @@ async fn listen_tcp(args: ListenTcpArgs) -> Result<()> {
};
let addrs = addrs.clone();
let handshake = !args.common.is_custom_alpn();
let authorized_keys = authorized_keys.clone();
tokio::spawn(async move {
if let Err(cause) = handle_magic_accept(connecting, addrs, handshake).await {
if let Err(cause) =
handle_magic_accept(connecting, addrs, handshake, authorized_keys).await
{
// log error at warn level
//
// we should know about it, but it's not fatal
Expand All @@ -533,6 +584,44 @@ async fn listen_tcp(args: ListenTcpArgs) -> Result<()> {
Ok(())
}

#[derive(Debug, Clone)]
struct AuthorizedKeys(Arc<Vec<NodeId>>);

impl AuthorizedKeys {
async fn load(path: impl AsRef<Path>) -> Result<Self> {
let path = path.as_ref();
let keys: Result<Vec<NodeId>> = tokio::fs::read_to_string(path)
.await
.with_context(|| format!("failed to read authorized keys file at {}", path.display()))?
.lines()
.filter_map(|line| line.split_whitespace().next())
.filter(|str| !str.starts_with('#'))
.map(|str| {
NodeId::from_str(str).with_context(|| {
format!("failed to parse node id `{str}` from authorized keys file")
})
})
.collect();
let keys = keys?;
info!("authorization is enabled: {} nodes authorized.", keys.len());
Ok(Self(Arc::new(keys)))
}

fn authorize(&self, connection: &Connection) -> Result<()> {
let remote = connection.remote_node_id()?;
if !self.0.contains(&remote) {
connection.close(403u32.into(), b"unauthorized");
info!(
remote = %remote.fmt_short(),
"rejecting connection: unauthorized",
);
Err(format_err!("connection rejected: unauthorized"))
} else {
Ok(())
}
}
}

#[tokio::main]
async fn main() -> Result<()> {
tracing_subscriber::fmt::init();
Expand Down
18 changes: 9 additions & 9 deletions tests/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ fn connect_listen_happy() {
.stderr_to_stdout() //
.reader()
.unwrap();
// read the first 3 lines of the header, and parse the last token as a ticket
let header = read_ascii_lines(3, &mut listen).unwrap();
// read the first 4 lines of the header, and parse the last token as a ticket
let header = read_ascii_lines(4, &mut listen).unwrap();
let header = String::from_utf8(header).unwrap();
let ticket = header.split_ascii_whitespace().last().unwrap();
let ticket = NodeTicket::from_str(ticket).unwrap();
Expand Down Expand Up @@ -105,8 +105,8 @@ fn connect_listen_custom_alpn_happy() {
.stderr_to_stdout() //
.reader()
.unwrap();
// read the first 3 lines of the header, and parse the last token as a ticket
let header = read_ascii_lines(3, &mut listen).unwrap();
// read the first 4 lines of the header, and parse the last token as a ticket
let header = read_ascii_lines(4, &mut listen).unwrap();
let header = String::from_utf8(header).unwrap();
let ticket = header.split_ascii_whitespace().last().unwrap();
let ticket = NodeTicket::from_str(ticket).unwrap();
Expand Down Expand Up @@ -149,8 +149,8 @@ fn connect_listen_ctrlc_connect() {
.stderr_to_stdout() //
.reader()
.unwrap();
// read the first 3 lines of the header, and parse the last token as a ticket
let header = read_ascii_lines(3, &mut listen).unwrap();
// read the first 4 lines of the header, and parse the last token as a ticket
let header = read_ascii_lines(4, &mut listen).unwrap();
let header = String::from_utf8(header).unwrap();
let ticket = header.split_ascii_whitespace().last().unwrap();
let ticket = NodeTicket::from_str(ticket).unwrap();
Expand Down Expand Up @@ -189,8 +189,8 @@ fn connect_listen_ctrlc_listen() {
.stderr_to_stdout()
.reader()
.unwrap();
// read the first 3 lines of the header, and parse the last token as a ticket
let header = read_ascii_lines(3, &mut listen).unwrap();
// read the first 4 lines of the header, and parse the last token as a ticket
let header = read_ascii_lines(4, &mut listen).unwrap();
let header = String::from_utf8(header).unwrap();
let ticket = header.split_ascii_whitespace().last().unwrap();
let ticket = NodeTicket::from_str(ticket).unwrap();
Expand Down Expand Up @@ -267,7 +267,7 @@ fn connect_tcp_happy() {
.stderr_to_stdout() //
.reader()
.unwrap();
let header = read_ascii_lines(3, &mut listen).unwrap();
let header = read_ascii_lines(4, &mut listen).unwrap();
let header = String::from_utf8(header).unwrap();
let ticket = header.split_ascii_whitespace().last().unwrap();
let ticket = NodeTicket::from_str(ticket).unwrap();
Expand Down
Loading