diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 20792f4d..3e4ee4e4 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -34,6 +34,7 @@ jobs: - name: Run tests run: | cargo check --features multiplex + cargo check --features rustls,native-tls,vendored cargo test test-linux-aarch64: @@ -48,6 +49,7 @@ jobs: - name: Run tests run: | cargo check --features multiplex + cargo check --features rustls,native-tls,vendored cargo test test-macos: @@ -62,6 +64,7 @@ jobs: - name: Run tests run: | cargo check --features multiplex + cargo check --features rustls,native-tls,vendored cargo test test-windows: @@ -76,6 +79,7 @@ jobs: - name: Run tests run: | cargo check --features multiplex + cargo check --features rustls,native-tls,vendored cargo test lint: diff --git a/Cargo.lock b/Cargo.lock index 15fd0710..85c7af3f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -964,6 +964,19 @@ dependencies = [ "tower-service", ] +[[package]] +name = "hyper-tls" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d6183ddfa99b85da61a140bea0efc93fdf56ceaa041b37d553518030827f9905" +dependencies = [ + "bytes", + "hyper 0.14.28", + "native-tls", + "tokio", + "tokio-native-tls", +] + [[package]] name = "hyper-util" version = "0.1.2" @@ -1449,6 +1462,15 @@ version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" +[[package]] +name = "openssl-src" +version = "300.2.1+3.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3fe476c29791a5ca0d1273c697e96085bbabbbea2ef7afd5617e78a4b40332d3" +dependencies = [ + "cc", +] + [[package]] name = "openssl-sys" version = "0.9.97" @@ -1457,6 +1479,7 @@ checksum = "c3eaad34cdd97d81de97964fc7f29e2d104f483840d906ef56daa1912338460b" dependencies = [ "cc", "libc", + "openssl-src", "pkg-config", "vcpkg", ] @@ -1918,10 +1941,12 @@ dependencies = [ "http-body 0.4.6", "hyper 0.14.28", "hyper-rustls", + "hyper-tls", "ipnet", "js-sys", "log", "mime", + "native-tls", "once_cell", "percent-encoding", "pin-project-lite", @@ -1932,11 +1957,14 @@ dependencies = [ "serde_urlencoded", "system-configuration", "tokio", + "tokio-native-tls", "tokio-rustls 0.24.1", + "tokio-util", "tower-service", "url", "wasm-bindgen", "wasm-bindgen-futures", + "wasm-streams", "web-sys", "webpki-roots", "winreg", @@ -2037,9 +2065,9 @@ dependencies = [ [[package]] name = "rustls-pki-types" -version = "1.0.1" +version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e7673e0aa20ee4937c6aacfc12bb8341cfbf054cdd21df6bec5fd0629fe9339b" +checksum = "9e9d979b3ce68192e42760c7810125eb6cf2ea10efae545a156063e61f314e2a" [[package]] name = "rustls-webpki" @@ -2778,6 +2806,8 @@ dependencies = [ "dashmap", "faststr", "futures", + "hyper 1.1.0", + "hyper-util", "lazy_static", "libc", "metainfo", @@ -2787,6 +2817,7 @@ dependencies = [ "once_cell", "pin-project", "rand", + "reqwest", "rustls 0.22.1", "socket2", "thiserror", @@ -2940,6 +2971,7 @@ dependencies = [ "chrono", "futures", "fxhash", + "http 1.0.0", "lazy_static", "linked-hash-map", "linkedbytes", @@ -3047,6 +3079,19 @@ version = "0.2.89" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7ab9b36309365056cd639da3134bf87fa8f3d86008abf99e612384a6eecd459f" +[[package]] +name = "wasm-streams" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4609d447824375f43e1ffbc051b50ad8f4b3ae8219680c94452ea05eb240ac7" +dependencies = [ + "futures-util", + "js-sys", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", +] + [[package]] name = "web-sys" version = "0.3.66" diff --git a/Cargo.toml b/Cargo.toml index cffe3301..3546f37f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -31,7 +31,7 @@ pilota-thrift-parser = "0.10" motore = "0.4" # motore = { git = "https://github.com/cloudwego/motore", branch = "main" } -metainfo = "0.7" +metainfo = "0.7.7" anyhow = "1" async-broadcast = "0.6" @@ -62,6 +62,7 @@ http-body-util = "0.1" hyper = "1" hyper-timeout = "0.5" hyper-util = "0.1.2" +reqwest = { version = "0.11", features = ["json", "stream"] } itertools = "0" lazy_static = "1" libc = "0.2" diff --git a/volo-grpc/src/transport/client.rs b/volo-grpc/src/transport/client.rs index cc3d7c85..7b0856d3 100644 --- a/volo-grpc/src/transport/client.rs +++ b/volo-grpc/src/transport/client.rs @@ -47,6 +47,7 @@ impl ClientTransport { rpc_config.connect_timeout, rpc_config.read_timeout, rpc_config.write_timeout, + None, ); let mut http_client = http2::Builder::new(TokioExecutor::new()); http_client @@ -79,6 +80,7 @@ impl ClientTransport { rpc_config.connect_timeout, rpc_config.read_timeout, rpc_config.write_timeout, + None, ); let mut http_client = http2::Builder::new(TokioExecutor::new()); http_client @@ -231,6 +233,8 @@ fn build_uri(addr: Address, path: &str) -> hyper::Uri { .path_and_query(path) .build() .expect("fail to build unix uri"), + Address::Http(url) => hyper::Uri::try_from(url.to_string()) + .expect("fail to build http uri"), } } diff --git a/volo-grpc/src/transport/connect.rs b/volo-grpc/src/transport/connect.rs index 096a5398..ae408b81 100644 --- a/volo-grpc/src/transport/connect.rs +++ b/volo-grpc/src/transport/connect.rs @@ -35,7 +35,7 @@ impl Connector { #[cfg(any(feature = "rustls", feature = "native-tls"))] pub fn new_with_tls(cfg: Option, tls_config: ClientTlsConfig) -> Self { - let mut mt = TlsMakeTransport::new(cfg.unwrap_or_default(), tls_config); + let mut mt = TlsMakeTransport::new(cfg.clone().unwrap_or_default(), tls_config); if let Some(cfg) = cfg { mt.set_connect_timeout(cfg.connect_timeout); mt.set_read_timeout(cfg.read_timeout); diff --git a/volo-thrift/Cargo.toml b/volo-thrift/Cargo.toml index c32f08c3..9aef0eba 100644 --- a/volo-thrift/Cargo.toml +++ b/volo-thrift/Cargo.toml @@ -28,6 +28,7 @@ bytes.workspace = true chrono.workspace = true futures.workspace = true fxhash.workspace = true +http.workspace = true lazy_static.workspace = true linkedbytes.workspace = true linked-hash-map.workspace = true diff --git a/volo-thrift/src/client/mod.rs b/volo-thrift/src/client/mod.rs index c8d80a1d..2b1c4680 100644 --- a/volo-thrift/src/client/mod.rs +++ b/volo-thrift/src/client/mod.rs @@ -11,6 +11,7 @@ use std::{ sync::{atomic::AtomicI32, Arc}, }; +use http::{HeaderMap, header::IntoHeaderName, HeaderValue}; use motore::{ layer::{Identity, Layer, Stack}, service::{BoxCloneService, Service}, @@ -53,6 +54,7 @@ pub struct ClientBuilder { callee_name: FastStr, caller_name: FastStr, address: Option
, // maybe address use Arc avoid memory alloc + headers: Option, inner_layer: IL, outer_layer: OL, make_transport: MkT, @@ -75,6 +77,7 @@ impl Req, Resp, DefaultMakeTransport, + DefaultMakeCodec>>, LbConfig::Key>, DummyDiscover>, > @@ -86,10 +89,12 @@ impl caller_name: "".into(), callee_name: FastStr::new(service_name), address: None, + headers: None, inner_layer: Identity::new(), outer_layer: Identity::new(), mk_client: service_client, make_transport: DefaultMakeTransport::default(), + make_codec: DefaultMakeCodec::default(), mk_lb: LbConfig::new(WeightedRandomBalance::new(), DummyDiscover {}), _marker: PhantomData, @@ -115,6 +120,7 @@ impl caller_name: self.caller_name, callee_name: self.callee_name, address: self.address, + headers: self.headers, inner_layer: self.inner_layer, outer_layer: self.outer_layer, mk_client: self.mk_client, @@ -140,6 +146,7 @@ impl caller_name: self.caller_name, callee_name: self.callee_name, address: self.address, + headers: self.headers, inner_layer: self.inner_layer, outer_layer: self.outer_layer, mk_client: self.mk_client, @@ -214,6 +221,7 @@ impl ClientBuilder ClientBuilder ClientBuilder ClientBuilder(mut self, key: K, value: HeaderValue) -> Self { + if let Some(existing) = &mut self.headers { + existing.append(key, value); + } else { + let mut headers = HeaderMap::new(); + headers.append(key, value); + self.headers = Some(headers); + } + self + } + + /// Add transport headers + /// + pub fn headers(mut self, headers: HeaderMap) -> Self { + if let Some(existing) = &mut self.headers { + existing.extend(headers); + } else { + self.headers = Some(headers); + } + self + } + /// Adds a new inner layer to the client. /// /// The layer's `Service` should be `Send + Sync + Clone + 'static`. @@ -322,6 +356,7 @@ impl ClientBuilder ClientBuilder ClientBuilder ClientBuilder), #[cfg(feature = "native-tls")] NativeTls(#[pin] tokio_native_tls::TlsStream), + Http(#[pin] Http), } cfg_rustls! { @@ -51,6 +53,7 @@ pub enum OwnedWriteHalf { Rustls(#[pin] RustlsWriteHalf), #[cfg(feature = "native-tls")] NativeTls(#[pin] NativeTlsWriteHalf), + Http(#[pin] HttpWriteHalf), } impl AsyncWrite for OwnedWriteHalf { @@ -68,6 +71,7 @@ impl AsyncWrite for OwnedWriteHalf { OwnedWriteHalfProj::Rustls(half) => half.poll_write(cx, buf), #[cfg(feature = "native-tls")] OwnedWriteHalfProj::NativeTls(half) => half.poll_write(cx, buf), + OwnedWriteHalfProj::Http(half) => half.poll_write(cx, buf), } } @@ -81,6 +85,7 @@ impl AsyncWrite for OwnedWriteHalf { OwnedWriteHalfProj::Rustls(half) => half.poll_flush(cx), #[cfg(feature = "native-tls")] OwnedWriteHalfProj::NativeTls(half) => half.poll_flush(cx), + OwnedWriteHalfProj::Http(half) => half.poll_flush(cx), } } @@ -94,6 +99,7 @@ impl AsyncWrite for OwnedWriteHalf { OwnedWriteHalfProj::Rustls(half) => half.poll_shutdown(cx), #[cfg(feature = "native-tls")] OwnedWriteHalfProj::NativeTls(half) => half.poll_shutdown(cx), + OwnedWriteHalfProj::Http(half) => half.poll_shutdown(cx), } } @@ -111,6 +117,7 @@ impl AsyncWrite for OwnedWriteHalf { OwnedWriteHalfProj::Rustls(half) => half.poll_write_vectored(cx, bufs), #[cfg(feature = "native-tls")] OwnedWriteHalfProj::NativeTls(half) => half.poll_write_vectored(cx, bufs), + OwnedWriteHalfProj::Http(half) => half.poll_write_vectored(cx, bufs), } } @@ -124,6 +131,7 @@ impl AsyncWrite for OwnedWriteHalf { Self::Rustls(half) => half.is_write_vectored(), #[cfg(feature = "native-tls")] Self::NativeTls(half) => half.is_write_vectored(), + Self::Http(half) => half.is_write_vectored(), } } } @@ -145,6 +153,7 @@ pub enum OwnedReadHalf { Rustls(#[pin] RustlsReadHalf), #[cfg(feature = "native-tls")] NativeTls(#[pin] NativeTlsReadHalf), + Http(#[pin] HttpReadHalf), } impl AsyncRead for OwnedReadHalf { @@ -162,6 +171,7 @@ impl AsyncRead for OwnedReadHalf { OwnedReadHalfProj::Rustls(half) => half.poll_read(cx, buf), #[cfg(feature = "native-tls")] OwnedReadHalfProj::NativeTls(half) => half.poll_read(cx, buf), + OwnedReadHalfProj::Http(half) => half.poll_read(cx, buf), } } } @@ -189,6 +199,10 @@ impl ConnStream { let (rh, wh) = tokio::io::split(stream); (OwnedReadHalf::NativeTls(rh), OwnedWriteHalf::NativeTls(wh)) } + Self::Http(stream) => { + let (rh, wh) = tokio::io::split(stream); + (OwnedReadHalf::Http(rh), OwnedWriteHalf::Http(wh)) + } } } } @@ -201,6 +215,13 @@ impl From for ConnStream { } } +impl From for ConnStream { + #[inline] + fn from(value: Http) -> Self { + Self::Http(value) + } +} + #[cfg(target_family = "unix")] impl From for ConnStream { #[inline] @@ -242,6 +263,7 @@ impl AsyncRead for ConnStream { IoStreamProj::Rustls(s) => s.poll_read(cx, buf), #[cfg(feature = "native-tls")] IoStreamProj::NativeTls(s) => s.poll_read(cx, buf), + IoStreamProj::Http(s) => s.poll_read(cx, buf), } } } @@ -261,6 +283,7 @@ impl AsyncWrite for ConnStream { IoStreamProj::Rustls(s) => s.poll_write(cx, buf), #[cfg(feature = "native-tls")] IoStreamProj::NativeTls(s) => s.poll_write(cx, buf), + IoStreamProj::Http(s) => s.poll_write(cx, buf), } } @@ -274,6 +297,7 @@ impl AsyncWrite for ConnStream { IoStreamProj::Rustls(s) => s.poll_flush(cx), #[cfg(feature = "native-tls")] IoStreamProj::NativeTls(s) => s.poll_flush(cx), + IoStreamProj::Http(s) => s.poll_flush(cx), } } @@ -287,6 +311,7 @@ impl AsyncWrite for ConnStream { IoStreamProj::Rustls(s) => s.poll_shutdown(cx), #[cfg(feature = "native-tls")] IoStreamProj::NativeTls(s) => s.poll_shutdown(cx), + IoStreamProj::Http(s) => s.poll_shutdown(cx), } } @@ -304,6 +329,7 @@ impl AsyncWrite for ConnStream { IoStreamProj::Rustls(s) => s.poll_write_vectored(cx, bufs), #[cfg(feature = "native-tls")] IoStreamProj::NativeTls(s) => s.poll_write_vectored(cx, bufs), + IoStreamProj::Http(s) => s.poll_write_vectored(cx, bufs), } } @@ -317,6 +343,7 @@ impl AsyncWrite for ConnStream { Self::Rustls(s) => s.is_write_vectored(), #[cfg(feature = "native-tls")] Self::NativeTls(s) => s.is_write_vectored(), + Self::Http(s) => s.is_write_vectored(), } } } @@ -338,9 +365,13 @@ impl ConnStream { .peer_addr() .map(Address::from) .ok(), + Self::Http(s) => { + // TOOD: should we return remote_addr? + s.meta.get_address() + } } } -} +} pub struct Conn { pub stream: ConnStream, pub info: ConnInfo, diff --git a/volo/src/net/dial.rs b/volo/src/net/dial.rs index efef8a9e..905a754e 100644 --- a/volo/src/net/dial.rs +++ b/volo/src/net/dial.rs @@ -1,5 +1,6 @@ use std::{future::Future, io, net::SocketAddr}; +use hyper::HeaderMap; use socket2::{Domain, Protocol, Socket, Type}; #[cfg(target_family = "unix")] use tokio::net::UnixStream; @@ -11,7 +12,7 @@ use tokio::{ use super::{ conn::{Conn, OwnedReadHalf, OwnedWriteHalf}, - Address, + Address, http::make_http_connection, }; /// [`MakeTransport`] creates an [`AsyncRead`] and an [`AsyncWrite`] for the given [`Address`]. @@ -26,18 +27,21 @@ pub trait MakeTransport: Clone + Send + Sync + 'static { fn set_connect_timeout(&mut self, timeout: Option); fn set_read_timeout(&mut self, timeout: Option); fn set_write_timeout(&mut self, timeout: Option); + + fn set_headers(&mut self, headers: Option); } -#[derive(Default, Debug, Clone, Copy)] +#[derive(Default, Debug, Clone)] pub struct DefaultMakeTransport { cfg: Config, } -#[derive(Default, Debug, Clone, Copy)] +#[derive(Default, Debug, Clone)] pub struct Config { pub connect_timeout: Option, pub read_timeout: Option, pub write_timeout: Option, + pub headers: Option, } impl Config { @@ -45,11 +49,13 @@ impl Config { connect_timeout: Option, read_timeout: Option, write_timeout: Option, + headers: Option, ) -> Self { Self { connect_timeout, read_timeout, write_timeout, + headers, } } @@ -67,6 +73,10 @@ impl Config { self.write_timeout = timeout; self } + pub fn with_headers(mut self, headers: Option) -> Self { + self.headers = headers; + self + } } impl DefaultMakeTransport { @@ -87,15 +97,19 @@ impl MakeTransport for DefaultMakeTransport { } fn set_connect_timeout(&mut self, timeout: Option) { - self.cfg = self.cfg.with_connect_timeout(timeout); + self.cfg = self.cfg.clone().with_connect_timeout(timeout); } fn set_read_timeout(&mut self, timeout: Option) { - self.cfg = self.cfg.with_read_timeout(timeout); + self.cfg = self.cfg.clone().with_read_timeout(timeout); } fn set_write_timeout(&mut self, timeout: Option) { - self.cfg = self.cfg.with_write_timeout(timeout); + self.cfg = self.cfg.clone().with_write_timeout(timeout); + } + + fn set_headers(&mut self, headers: Option) { + self.cfg = self.cfg.clone().with_headers(headers); } } @@ -136,6 +150,10 @@ impl DefaultMakeTransport { } #[cfg(target_family = "unix")] Address::Unix(addr) => UnixStream::connect(addr).await.map(Conn::from), + Address::Http(url) => { + let stream = make_http_connection(&self.cfg, url).await?; + Ok(Conn::from(stream)) + } } } } @@ -204,8 +222,8 @@ cfg_rustls_or_native_tls! { match &self.tls_config.connector { #[cfg(feature = "rustls")] TlsConnector::Rustls(connector) => { - let server_name = librustls::ServerName::try_from(&self.tls_config.server_name[..]) - .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?; + let server_name = librustls::pki_types::ServerName::try_from(&self.tls_config.server_name[..]) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?.to_owned(); connector .connect(server_name, tcp) .await @@ -225,6 +243,8 @@ cfg_rustls_or_native_tls! { } #[cfg(target_family = "unix")] Address::Unix(addr) => UnixStream::connect(addr).await.map(Conn::from), + + Address::Http(url) => todo!(), } } } @@ -241,15 +261,19 @@ cfg_rustls_or_native_tls! { } fn set_connect_timeout(&mut self, timeout: Option) { - self.cfg = self.cfg.with_connect_timeout(timeout); + self.cfg = self.cfg.clone().with_connect_timeout(timeout); } fn set_read_timeout(&mut self, timeout: Option) { - self.cfg = self.cfg.with_read_timeout(timeout); + self.cfg = self.cfg.clone().with_read_timeout(timeout); } fn set_write_timeout(&mut self, timeout: Option) { - self.cfg = self.cfg.with_write_timeout(timeout); + self.cfg = self.cfg.clone().with_write_timeout(timeout); + } + + fn set_headers(&mut self, headers: Option) { + todo!() } } } diff --git a/volo/src/net/http.rs b/volo/src/net/http.rs new file mode 100644 index 00000000..71ae45b6 --- /dev/null +++ b/volo/src/net/http.rs @@ -0,0 +1,314 @@ +use std::{sync::{Arc, Mutex}, time::Duration, io, collections::HashMap}; + + +use futures::Future; +use hyper::HeaderMap; +use pin_project::pin_project; +use tokio::io::{DuplexStream, AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt}; + +use super::{Address, dial::{MakeTransport, Config}}; + +const WINDOW_SIZE: usize = 0x4000; // 4 * 4 + +struct WaitFlushed { + inner: Arc>, +} + +impl Future for WaitFlushed { + type Output = (); + + fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll { + cx.waker().wake_by_ref(); + let val = self.inner.lock().unwrap(); + // println!("fut check {val}"); + if !*val { + std::task::Poll::Pending + } else { + std::task::Poll::Ready(()) + } + } +} + +#[derive(Clone)] +pub struct HttpMeta { + is_flushed: Arc>, + address: Option, +} + +impl Default for HttpMeta { + fn default() -> Self { + Self { + address: None, + is_flushed: Arc::new(Mutex::new(false)), + } + } +} + +impl HttpMeta { + pub fn with_address(mut self, addr: Address) -> Self { + if let Address::Http(url) = addr { + self.address = Some(url); + } + self + } + + pub fn get_url(&self) -> Option { + self.address.clone() + } + + pub fn get_address(&self) -> Option
{ + self.address.clone() + .map(|addr| + Address::Http(addr)) + } + + pub fn wait_flushed(&self) -> impl Future { + WaitFlushed { + inner: self.is_flushed.clone(), + } + } + pub fn reset_flushed(&self) { + *self.is_flushed.lock().unwrap() = false; + } +} + +#[pin_project] +pub struct HttpStream { + #[pin] + reader: R, + #[pin] + writer: W, + + pub meta: HttpMeta, +} + +type HttpReadHalfInternal = tokio::io::ReadHalf; +type HttpWriteHalfInternal = tokio::io::WriteHalf; +pub type Http = HttpStream; + +pub type HttpReadHalf = tokio::io::ReadHalf; +pub type HttpWriteHalf = tokio::io::WriteHalf; + +impl Http +{ + pub fn new(stream: DuplexStream) -> Self { + let (rd, wr) = tokio::io::split(stream); + Self { + reader: rd, + writer: wr, + meta: HttpMeta::default(), + } + } + + // pub fn get_meta(&self) -> HttpMeta { + // self.meta.clone() + // } + + pub fn with_meta(mut self, meta: HttpMeta) -> Self { + self.meta = meta; + self + } +} + +impl AsyncRead for Http +{ + fn poll_read( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> std::task::Poll> { + self.project().reader.poll_read(cx, buf) + } +} + +impl AsyncWrite for Http +{ + fn poll_write( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> std::task::Poll> { + self.project().writer.poll_write(cx, buf) + } + + fn poll_flush(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll> { + let this = self.project(); + let result = this.writer.poll_flush(cx); + if result.is_ready() { + *this.meta.is_flushed.lock().unwrap() = true; + } + result + } + + fn poll_shutdown(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll> { + self.project().writer.poll_shutdown(cx) + } +} + +#[derive(Clone)] +pub struct HttpTransport { + headers: Option, + + connect_timeout: Option, + read_timeout: Option, +} + +impl Default for HttpTransport { + fn default() -> Self { + Self { + connect_timeout: None, + read_timeout: None, + headers: None, + } + } +} + +impl HttpTransport { + fn new(cfg: &Config) -> Self { + Self { + connect_timeout: cfg.connect_timeout, + read_timeout: cfg.read_timeout, + headers: cfg.headers.clone(), + ..Default::default() + } + } + + fn builder(&self) -> reqwest::ClientBuilder { + reqwest::Client::builder() + } + fn build_client(&self) -> reqwest::Client { + let mut builder = self.builder(); + if let Some(timeout) = self.connect_timeout { + builder = builder.connect_timeout(timeout); + } + if let Some(timeout) = self.read_timeout { + builder = builder.timeout(timeout) + } + builder.build().unwrap() + } + fn get_headers(&self) -> Option { + self.headers.clone() + } + + async fn make_transport_conn( + &self, + addr: Address, + ) -> std::io::Result { + let (cs, mut sc) = tokio::io::duplex(WINDOW_SIZE); + let cs = Http::new(cs) + .with_meta(HttpMeta::default() + .with_address(addr)); + + let meta = cs.meta.clone(); + let client = self.build_client(); + let headers = self.get_headers(); + + tokio::spawn(async move { + let url = meta.get_url().unwrap(); + // let mut i = 0; + loop { + // println!("reqnr {i}"); i += 1; + let mut payload = Vec::with_capacity(WINDOW_SIZE); + meta.wait_flushed().await; + meta.reset_flushed(); + match sc.read_buf(&mut payload).await { + Ok(siz) => { + if siz == 0 { + eprintln!("got transport error EOF"); + return; + } + let mut req = client.post(url.to_string()); + if let Some(headers) = &headers { + let headers = headers + .iter() + .map(|(key, val)| (key.to_string(), val.to_str().unwrap().to_string())) + .collect::>(); + + let headers: reqwest::header::HeaderMap = (&headers) + .try_into() + .expect("valid headers"); + req = req.headers(headers); + } + // println!("rpc_payload = {:?}", payload); + // println!("headers = {:?}", headers); + let req = req.body(reqwest::Body::from(payload)) + .build() + .unwrap() + ; + let resp = client.execute(req).await; + match resp { + Ok(mut resp) => { + if resp.status() != reqwest::StatusCode::OK { + eprintln!("got transport response not OK"); + return; + } + // check for content-type? + + while let Ok(Some(chunk)) = resp.chunk().await { + if let Err(e) = sc.write(&chunk).await { + eprintln!("got transport error response download {:#?}", e); + return; + } + } + sc.flush().await.unwrap(); + } + Err(e) => { + eprintln!("got transport error transmit {:#?}", e); + return; + } + } + } + Err(e) => { + eprintln!("got transport error {:#?}", e); + return; + } + } + } + }); + + Ok(cs) + } +} + +impl MakeTransport for HttpTransport { + type ReadHalf = HttpReadHalf; + type WriteHalf = HttpWriteHalf; + + fn set_connect_timeout(&mut self, timeout: Option) { + self.connect_timeout = timeout; + } + + fn set_read_timeout(&mut self, timeout: Option) { + self.read_timeout = timeout; + } + + // set_write_timeout is no-op, write should be buffered. + fn set_write_timeout(&mut self, _timeout: Option) { + } + + fn set_headers(&mut self, headers: Option) { + self.headers = headers; + } + + async fn make_transport( + &self, + addr: Address, + ) -> std::io::Result<(Self::ReadHalf, Self::WriteHalf)>{ + let cs = self.make_transport_conn(addr).await?; + let (csr, csw) = tokio::io::split(cs); + Ok((csr, csw)) + } +} + +// impl MakeIncoming for HttpTransport { +// type Incoming = (); +// async fn make_incoming(self) -> io::Result { +// todo!() +// } +// } + +pub async fn make_http_connection(cfg: &Config, url: hyper::Uri) -> Result { + let trans = HttpTransport::new(cfg); + let cs = trans.make_transport_conn(Address::Http(url)).await?; + Ok(cs) +} \ No newline at end of file diff --git a/volo/src/net/incoming.rs b/volo/src/net/incoming.rs index e1750167..8f52e345 100644 --- a/volo/src/net/incoming.rs +++ b/volo/src/net/incoming.rs @@ -80,6 +80,7 @@ impl MakeIncoming for Address { let listener = unix_helper::create_unix_listener_with_max_backlog(addr).await; UnixListener::from_std(listener?).map(DefaultIncoming::from) } + Address::Http(_url) => todo!("not implemented yet"), } } } @@ -93,6 +94,7 @@ impl MakeIncoming for Address { Address::Ip(addr) => TcpListener::bind(addr).await.map(DefaultIncoming::from), #[cfg(target_family = "unix")] Address::Unix(addr) => UnixListener::bind(addr).map(DefaultIncoming::from), + Address::Http(_url) => todo!("not implemented yet") } } } diff --git a/volo/src/net/mod.rs b/volo/src/net/mod.rs index 8620e9e2..78ef0e24 100644 --- a/volo/src/net/mod.rs +++ b/volo/src/net/mod.rs @@ -1,12 +1,14 @@ pub mod conn; pub mod dial; pub mod incoming; +mod http; mod probe; #[cfg(target_family = "unix")] use std::{borrow::Cow, path::Path}; use std::{fmt, net::Ipv6Addr}; + pub use incoming::{DefaultIncoming, MakeIncoming}; #[derive(Clone, Debug, PartialEq, Eq, Hash)] @@ -14,6 +16,7 @@ pub enum Address { Ip(std::net::SocketAddr), #[cfg(target_family = "unix")] Unix(Cow<'static, Path>), + Http(hyper::Uri) } impl Address { @@ -26,6 +29,7 @@ impl Address { self } } + Address::Http(_) => self, #[cfg(target_family = "unix")] _ => self, } @@ -43,10 +47,17 @@ impl fmt::Display for Address { Address::Ip(addr) => write!(f, "{addr}"), #[cfg(target_family = "unix")] Address::Unix(path) => write!(f, "{}", path.display()), + Address::Http(url) => write!(f, "{url}"), } } } +impl From for Address { + fn from(url: hyper::Uri) -> Self { + Address::Http(url) + } +} + impl From for Address { fn from(addr: std::net::SocketAddr) -> Self { Address::Ip(addr)