From 00c5c2bd63e2fda63c5c5d2f981e4450871ecc82 Mon Sep 17 00:00:00 2001 From: Romain Mallard Date: Fri, 26 Sep 2025 02:23:25 +0200 Subject: [PATCH] simple websocket implementation (without auth use) --- Cargo.lock | 76 +++++++++++++++++++++++++++++ Cargo.toml | 3 +- src/main.rs | 15 ++++++ src/utils/db_pool.rs | 12 +++++ src/utils/mod.rs | 3 +- src/utils/routes.rs | 4 +- src/utils/websocket.rs | 108 +++++++++++++++++++++++++++++++++++++++++ 7 files changed, 218 insertions(+), 3 deletions(-) create mode 100644 src/utils/websocket.rs diff --git a/Cargo.lock b/Cargo.lock index 0bdf116..26490ad 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -63,6 +63,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "021e862c184ae977658b36c4500f7feac3221ca5da43e3f25bd04ab6c79a29b5" dependencies = [ "axum-core", + "base64", "bytes", "form_urlencoded", "futures-util", @@ -82,8 +83,10 @@ dependencies = [ "serde_json", "serde_path_to_error", "serde_urlencoded", + "sha1", "sync_wrapper", "tokio", + "tokio-tungstenite", "tower", "tower-layer", "tower-service", @@ -286,6 +289,12 @@ dependencies = [ "parking_lot_core", ] +[[package]] +name = "data-encoding" +version = "2.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2a2330da5de22e8a3cb63252ce2abb30116bf5265e89c0e01bc17015ce30a476" + [[package]] name = "deranged" version = "0.5.3" @@ -366,6 +375,23 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" +[[package]] +name = "futures-macro" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "futures-sink" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e575fab7d1e0dcb8d0c7bcf9a63ee213816ab51902e6d244a95819acacf1d4f7" + [[package]] name = "futures-task" version = "0.3.31" @@ -379,9 +405,12 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" dependencies = [ "futures-core", + "futures-macro", + "futures-sink", "futures-task", "pin-project-lite", "pin-utils", + "slab", ] [[package]] @@ -460,6 +489,7 @@ dependencies = [ "chrono", "dashmap", "dotenvy", + "futures-util", "jsonwebtoken", "r2d2", "r2d2_sqlite", @@ -1050,6 +1080,17 @@ dependencies = [ "serde", ] +[[package]] +name = "sha1" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + [[package]] name = "shlex" version = "1.3.0" @@ -1204,6 +1245,18 @@ dependencies = [ "syn", ] +[[package]] +name = "tokio-tungstenite" +version = "0.26.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a9daff607c6d2bf6c16fd681ccb7eecc83e4e2cdc1ca067ffaadfca5de7f084" +dependencies = [ + "futures-util", + "log", + "tokio", + "tungstenite", +] + [[package]] name = "tower" version = "0.5.2" @@ -1252,6 +1305,23 @@ dependencies = [ "once_cell", ] +[[package]] +name = "tungstenite" +version = "0.26.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4793cb5e56680ecbb1d843515b23b6de9a75eb04b66643e256a396d43be33c13" +dependencies = [ + "bytes", + "data-encoding", + "http", + "httparse", + "log", + "rand", + "sha1", + "thiserror", + "utf-8", +] + [[package]] name = "typenum" version = "1.18.0" @@ -1270,6 +1340,12 @@ version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" +[[package]] +name = "utf-8" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" + [[package]] name = "uuid" version = "1.18.1" diff --git a/Cargo.toml b/Cargo.toml index 3e51451..866cadf 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,7 +5,7 @@ edition = "2024" publish = false [dependencies] -axum = "0.8.4" +axum = {version = "0.8.4", features = ["ws"]} tokio = { version = "1", features = ["full"] } serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" @@ -23,6 +23,7 @@ chrono = "0.4.42" anyhow = "1.0.100" argon2 = {version = "0.5.3"} rand_core = {version = "0.6.4", features = ["getrandom"]} +futures-util = {version = "0.3.31"} diff --git a/src/main.rs b/src/main.rs index 46b8c6a..3344437 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,7 +1,9 @@ use axum::serve; use axum::Extension; +use axum::extract::{ws::{Message, WebSocket, WebSocketUpgrade}, State}; use jsonwebtoken::{DecodingKey, EncodingKey}; use tokio::net::TcpListener; +use tokio::sync::mpsc; mod utils; mod routes; @@ -9,6 +11,8 @@ mod rooms; mod chat; use r2d2::{Pool}; use r2d2_sqlite::SqliteConnectionManager; +use dashmap::DashMap; +use std::sync::Arc; use crate::utils::db_pool::{HotelPool,AppState}; use routes::create_router; @@ -16,6 +20,7 @@ use crate::utils::auth::JwtKeys; + #[tokio::main] async fn main() -> std::io::Result<()> { @@ -26,9 +31,17 @@ async fn main() -> std::io::Result<()> { .build(logs_manager) .expect("Failed to build logs pool"); + type UserMap = DashMap>; + /// hotel_id → users + type HotelMap = DashMap>; + /// global map of all hotels + type WsMap = Arc; + + let state = AppState { hotel_pools, logs_pool, + ws_map: Arc::new(DashMap::new()), //jwt_secret: "your_jwt_secret_key s".to_string(), // better: load from env var }; @@ -38,6 +51,8 @@ async fn main() -> std::io::Result<()> { decoding: DecodingKey::from_secret(jwt_secret.as_ref()), }; + + let app = create_router(state) .layer(Extension(jwt_keys)); diff --git a/src/utils/db_pool.rs b/src/utils/db_pool.rs index 7d1d5f8..2181cd5 100644 --- a/src/utils/db_pool.rs +++ b/src/utils/db_pool.rs @@ -2,13 +2,25 @@ use std::sync::Arc; use dashmap::DashMap; use r2d2::{Pool}; use r2d2_sqlite::SqliteConnectionManager; +use tokio::sync::mpsc; +use axum::extract::{ws::{Message, WebSocket, WebSocketUpgrade}, State}; type HotelId = i32; // or i32 if you want numeric ids + /// Type alias: user_id → sender to that user +type UserMap = DashMap>; +/// hotel_id → users +type HotelMap = DashMap>; +/// global map of all hotels +type WsMap = Arc; + /// Type alias: user_id → sender to that user + + #[derive(Clone)] pub struct AppState { pub hotel_pools: HotelPool, pub logs_pool: Pool, + pub ws_map: WsMap, } diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 1f4231e..3ec42dc 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -1,3 +1,4 @@ pub mod db_pool; pub mod auth; -pub mod routes; \ No newline at end of file +pub mod routes; +pub mod websocket; \ No newline at end of file diff --git a/src/utils/routes.rs b/src/utils/routes.rs index 6974982..ea118a0 100644 --- a/src/utils/routes.rs +++ b/src/utils/routes.rs @@ -5,7 +5,7 @@ use axum::{ use crate::utils::auth::*; use crate::utils::db_pool::{HotelPool, AppState, }; - +use crate::utils::websocket::ws_handler; // ROOTS pub fn utils_routes() -> Router { @@ -13,6 +13,8 @@ pub fn utils_routes() -> Router { Router::new() .route("/login", put(clean_auth_loging)) .route("/register", put(register_user)) + .route("/ws/{hotel_id}/{user_id}", get(ws_handler)) .route("/tokentest", put(token_tester)) + //.with_state(state) } \ No newline at end of file diff --git a/src/utils/websocket.rs b/src/utils/websocket.rs new file mode 100644 index 0000000..b9117b0 --- /dev/null +++ b/src/utils/websocket.rs @@ -0,0 +1,108 @@ +use dashmap::DashMap; +use std::sync::Arc; +use axum::extract::{ws::{Message, WebSocket, WebSocketUpgrade}, State}; +use tokio::sync::mpsc; +use axum::extract::Path; +use axum::response::IntoResponse; +//use futures_util::stream::stream::StreamExt; +use futures_util::{StreamExt, SinkExt}; + +use crate::utils::db_pool::{HotelPool,AppState}; + + + + /// Type alias: user_id → sender to that user +type UserMap = DashMap>; +/// hotel_id → users +type HotelMap = DashMap>; +/// global map of all hotels +type WsMap = Arc; + /// Type alias: user_id → sender to that user + + +async fn handle_socket( + mut socket: WebSocket, + state: AppState, + hotel_id: i32, + user_id: i32, +) { + // channel for sending messages TO this client + let (tx, mut rx) = mpsc::unbounded_channel::(); + + // insert into hotel → user map + let user_map = state + .ws_map + .entry(hotel_id) + .or_insert_with(|| Arc::new(UserMap::new())) + .clone(); + + user_map.insert(user_id, tx); + + // ✅ print after upgrading + print_ws_state(&state); + + // split socket into sender/receiver + let (mut sender, mut receiver) = socket.split(); + + + + // task for sending messages from server to client + let mut rx_task = tokio::spawn(async move { + while let Some(msg) = rx.recv().await { + if sender.send(msg).await.is_err() { + break; + } + } + }); + + // task for receiving messages from client + let state_clone = state.clone(); + let mut recv_task = tokio::spawn(async move { + while let Some(Ok(msg)) = receiver.next().await { + match msg { + Message::Text(text) => { + println!("Hotel {hotel_id}, User {user_id} said: {text}"); + // echo back just as an example + if let Some(hotel_entry) = state_clone.ws_map.get(&hotel_id) { + if let Some(sender) = hotel_entry.get(&user_id) { + let _ = sender.send(Message::Text(format!("echo: {text}").into())); + } + } + } + Message::Close(_) => break, + _ => {} + } + } + }); + + // wait for either side to finish + tokio::select! { + _ = (&mut rx_task) => recv_task.abort(), + _ = (&mut recv_task) => rx_task.abort(), + } + + // cleanup + user_map.remove(&user_id); + if user_map.is_empty() { + state.ws_map.remove(&hotel_id); + } +} + +pub async fn ws_handler( + ws: WebSocketUpgrade, + State(state): State, + Path((hotel_id, user_id)): Path<(i32, i32)>, +) -> impl IntoResponse { + ws.on_upgrade(move |socket| handle_socket(socket, state, hotel_id, user_id)) +} + +fn print_ws_state(state: &AppState) { + println!("--- Current WebSocket state ---"); + for hotel_entry in state.ws_map.iter() { + let hotel_id = *hotel_entry.key(); + let user_map = hotel_entry.value(); + let users: Vec<_> = user_map.iter().map(|u| *u.key()).collect(); + println!("Hotel {hotel_id}: users {:?}", users); + } + println!("--------------------------------"); +} \ No newline at end of file