1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80
|
//! Run with `cargo run --all-features --example rustls_session` command.
//!
//! To connect through a browser, navigate to "https://localhost:3000".
use axum::{middleware::AddExtension, routing::get, Extension, Router};
use axum_server::{
accept::Accept,
tls_rustls::{RustlsAcceptor, RustlsConfig},
};
use futures_util::future::BoxFuture;
use std::{io, net::SocketAddr, sync::Arc};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_rustls::server::TlsStream;
use tower::Layer;
/// Serves the example router over HTTPS on `127.0.0.1:3000`.
///
/// The TLS certificate and key are read from the bundled self-signed PEM
/// files, and connections are accepted through [`CustomAcceptor`].
///
/// # Panics
///
/// Panics if the PEM files cannot be loaded or if the server returns an error.
#[tokio::main]
async fn main() {
    let router = Router::new().route("/", get(handler));

    let tls_config = RustlsConfig::from_pem_file(
        "examples/self-signed-certs/cert.pem",
        "examples/self-signed-certs/key.pem",
    )
    .await
    .unwrap();

    let listen_addr = SocketAddr::from(([127, 0, 0, 1], 3000));
    println!("listening on {}", listen_addr);

    // Wrap the stock rustls acceptor with the custom one defined in this file.
    axum_server::bind(listen_addr)
        .acceptor(CustomAcceptor::new(RustlsAcceptor::new(tls_config)))
        .serve(router.into_make_service())
        .await
        .unwrap();
}
/// Route handler for `/`: formats the request's `Extension<TlsData>` with its
/// `Debug` representation and returns the resulting string.
async fn handler(tls_data: Extension<TlsData>) -> String {
// The whole `Extension` wrapper is formatted, not only the inner `TlsData`.
format!("{:?}", tls_data)
}
/// Per-connection TLS information that is attached to the service as an
/// `Extension` and read back by the request handler.
#[derive(Debug, Clone)]
struct TlsData {
/// Value of the TLS connection's `server_name()`, or `None` when it reports
/// none.
// NOTE(review): the leading underscore presumably silences the unused-field
// lint, since the field is only read through the derived `Debug` — confirm.
_hostname: Option<Arc<str>>,
}
/// Acceptor that delegates to a `RustlsAcceptor` and then adds a `TlsData`
/// extension to the service of each accepted connection.
#[derive(Debug, Clone)]
struct CustomAcceptor {
/// Underlying acceptor; cloned for each `accept` call.
inner: RustlsAcceptor,
}
impl CustomAcceptor {
fn new(inner: RustlsAcceptor) -> Self {
Self { inner }
}
}
impl<I, S> Accept<I, S> for CustomAcceptor
where
I: AsyncRead + AsyncWrite + Unpin + Send + 'static,
S: Send + 'static,
{
type Stream = TlsStream<I>;
type Service = AddExtension<S, TlsData>;
type Future = BoxFuture<'static, io::Result<(Self::Stream, Self::Service)>>;
fn accept(&self, stream: I, service: S) -> Self::Future {
let acceptor = self.inner.clone();
Box::pin(async move {
let (stream, service) = acceptor.accept(stream, service).await?;
let server_conn = stream.get_ref().1;
let sni_hostname = TlsData {
_hostname: server_conn.server_name().map(From::from),
};
let service = Extension(sni_hostname).layer(service);
Ok((stream, service))
})
}
}
|