1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200
|
//! A chat server that broadcasts a message to all connections.
//!
//! This is a simple line-based server which accepts WebSocket connections,
//! reads lines from those connections, and broadcasts the lines to all other
//! connected clients.
//!
//! You can test this out by running:
//!
//! cargo run --example server 127.0.0.1:12345
//!
//! And then in another window run:
//!
//! cargo run --example client ws://127.0.0.1:12345/socket
//!
//! You can run the second command in multiple windows and then chat between
//! them, seeing the messages from the other clients as they're received. All
//! connected clients join the same room and see everyone else's messages.
use std::{
collections::HashMap,
convert::Infallible,
env,
net::SocketAddr,
sync::{Arc, Mutex},
};
use hyper::{
body::Incoming,
header::{
HeaderValue, CONNECTION, SEC_WEBSOCKET_ACCEPT, SEC_WEBSOCKET_KEY, SEC_WEBSOCKET_VERSION,
UPGRADE,
},
server::conn::http1,
service::service_fn,
upgrade::Upgraded,
Method, Request, Response, StatusCode, Version,
};
use hyper_util::rt::TokioIo;
use tokio::net::TcpListener;
use futures_channel::mpsc::{unbounded, UnboundedSender};
use futures_util::{future, pin_mut, stream::TryStreamExt, StreamExt};
use async_tungstenite::{tokio::TokioAdapter, WebSocketStream};
use tungstenite::{
handshake::derive_accept_key,
protocol::{Message, Role},
};
/// Sending half of the unbounded channel that queues outgoing messages for one peer.
type Tx = UnboundedSender<Message>;
/// Shared, mutex-protected map from each connected peer's address to its sender.
type PeerMap = Arc<Mutex<HashMap<SocketAddr, Tx>>>;
/// Response body type returned by `handle_request`.
type Body = http_body_util::Full<hyper::body::Bytes>;
/// Drives one established WebSocket connection until either direction finishes.
///
/// The peer is registered in `peer_map` under `addr`, every message read from
/// `ws_stream` is sent to the channel of every *other* registered peer, and
/// everything arriving on this peer's own channel is written back out on the
/// socket. The peer is removed from `peer_map` before returning.
///
/// * `peer_map` - shared registry of connected peers and their senders.
/// * `ws_stream` - the upgraded WebSocket stream for this peer.
/// * `addr` - remote address of this peer, used as its key in `peer_map`.
///
/// # Panics
///
/// Panics if the `peer_map` mutex has been poisoned.
async fn handle_connection(
    peer_map: PeerMap,
    ws_stream: WebSocketStream<TokioAdapter<TokioIo<Upgraded>>>,
    addr: SocketAddr,
) {
    println!("WebSocket connection established: {}", addr);

    // Insert the write part of this peer to the peer map.
    let (tx, rx) = unbounded();
    peer_map.lock().unwrap().insert(addr, tx);

    let (outgoing, incoming) = ws_stream.split();

    let broadcast_incoming = incoming.try_for_each(|msg| {
        // Not every message is valid UTF-8 text (e.g. binary payloads), so
        // never unwrap the conversion: a single such message must not panic.
        match msg.to_text() {
            Ok(text) => println!("Received a message from {}: {}", addr, text),
            Err(_) => println!("Received a non-text message from {}", addr),
        }

        // The closure is synchronous, so the guard is never held across an await.
        let peers = peer_map.lock().unwrap();

        // We want to broadcast the message to everyone except ourselves.
        let broadcast_recipients = peers
            .iter()
            .filter(|(peer_addr, _)| peer_addr != &&addr)
            .map(|(_, ws_sink)| ws_sink);

        for recp in broadcast_recipients {
            // A send only fails when the recipient's receiver is already gone,
            // i.e. that peer is disconnecting and about to remove itself from
            // the map. Panicking here would poison the mutex and take down
            // every other connection, so the failure is deliberately ignored.
            let _ = recp.unbounded_send(msg.clone());
        }

        future::ok(())
    });

    let receive_from_others = rx.map(Ok).forward(outgoing);

    // Stop as soon as either the inbound or the outbound direction completes.
    pin_mut!(broadcast_incoming, receive_from_others);
    future::select(broadcast_incoming, receive_from_others).await;

    println!("{} disconnected", &addr);
    peer_map.lock().unwrap().remove(&addr);
}
/// Answers one HTTP request, upgrading it to a WebSocket when it is a valid handshake.
///
/// A request counts as a handshake when it is a `GET` over HTTP/1.1 or newer
/// for the URI `/socket`, its `Connection` header lists the `Upgrade` token,
/// its `Upgrade` header is `websocket`, its `Sec-WebSocket-Version` is `13`
/// and it carries a `Sec-WebSocket-Key`. Anything else receives a plain
/// `Hello World!` response.
///
/// For a handshake, a task is spawned that waits for the connection upgrade
/// and then hands the resulting stream to `handle_connection`, while this
/// function returns the `101 Switching Protocols` response.
///
/// * `peer_map` - shared registry passed on to `handle_connection`.
/// * `req` - the incoming request; moved into the upgrade task on success.
/// * `addr` - remote address of the client that sent the request.
async fn handle_request(
    peer_map: PeerMap,
    mut req: Request<Incoming>,
    addr: SocketAddr,
) -> Result<Response<Body>, Infallible> {
    println!("Received a new, potentially ws handshake");
    println!("The request's path is: {}", req.uri().path());
    println!("The request's headers are:");
    for (name, _) in req.headers().iter() {
        println!("* {}", name);
    }

    let upgrade_value = HeaderValue::from_static("Upgrade");
    let websocket_value = HeaderValue::from_static("websocket");
    let upgrade_token = upgrade_value.to_str().unwrap();

    let request_headers = req.headers();
    let derived_accept = request_headers
        .get(SEC_WEBSOCKET_KEY)
        .map(|client_key| derive_accept_key(client_key.as_bytes()));

    // The `Connection` header may list several tokens separated by commas and/or spaces.
    let connection_has_upgrade = request_headers
        .get(CONNECTION)
        .and_then(|value| value.to_str().ok())
        .map_or(false, |value| {
            value
                .split(|c: char| c == ' ' || c == ',')
                .any(|token| token.eq_ignore_ascii_case(upgrade_token))
        });
    let upgrade_is_websocket = request_headers
        .get(UPGRADE)
        .and_then(|value| value.to_str().ok())
        .map_or(false, |value| value.eq_ignore_ascii_case("websocket"));
    let version_is_13 = request_headers
        .get(SEC_WEBSOCKET_VERSION)
        .map_or(false, |value| value == "13");

    let is_handshake = req.method() == Method::GET
        && req.version() >= Version::HTTP_11
        && connection_has_upgrade
        && upgrade_is_websocket
        && version_is_13
        && req.uri() == "/socket";

    // Without a client key there is no accept key, so that case falls through
    // to the plain response together with every other failed check.
    let accept_key = match derived_accept {
        Some(derived) if is_handshake => derived,
        _ => return Ok(Response::new(Body::from("Hello World!"))),
    };

    let ver = req.version();

    // The upgrade only resolves once the 101 response below has been sent,
    // so it has to be awaited on a separate task.
    tokio::task::spawn(async move {
        let upgraded = match hyper::upgrade::on(&mut req).await {
            Ok(upgraded) => upgraded,
            Err(e) => {
                println!("upgrade error: {}", e);
                return;
            }
        };
        let ws_stream = WebSocketStream::from_raw_socket(
            TokioAdapter::new(TokioIo::new(upgraded)),
            Role::Server,
            None,
        )
        .await;
        handle_connection(peer_map, ws_stream, addr).await;
    });

    let mut response = Response::new(Body::default());
    *response.status_mut() = StatusCode::SWITCHING_PROTOCOLS;
    *response.version_mut() = ver;

    let response_headers = response.headers_mut();
    response_headers.append(CONNECTION, upgrade_value);
    response_headers.append(UPGRADE, websocket_value);
    response_headers.append(SEC_WEBSOCKET_ACCEPT, accept_key.parse().unwrap());
    // Let's add an additional header to our response to the client.
    response_headers.append("MyCustomHeader", ":)".parse().unwrap());
    response_headers.append("SOME_TUNGSTENITE_HEADER", "header_value".parse().unwrap());

    Ok(response)
}
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let state = PeerMap::new(Mutex::new(HashMap::new()));
let addr = env::args()
.nth(1)
.unwrap_or_else(|| "127.0.0.1:8080".to_string())
.parse::<SocketAddr>()?;
let listener = TcpListener::bind(addr).await?;
loop {
let (stream, addr) = listener.accept().await?;
let state = state.clone();
tokio::spawn(async move {
let io = TokioIo::new(stream);
let service = service_fn(move |req| handle_request(state.clone(), req, addr));
let conn = http1::Builder::new()
.serve_connection(io, service)
.with_upgrades();
if let Err(err) = conn.await {
eprintln!("failed to serve connection: {err:?}");
}
});
}
}
|