simple websocket implementation (without auth use)

This commit is contained in:
2025-09-26 02:23:25 +02:00
parent ab0fbbce79
commit 00c5c2bd63
7 changed files with 218 additions and 3 deletions

76
Cargo.lock generated
View File

@@ -63,6 +63,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "021e862c184ae977658b36c4500f7feac3221ca5da43e3f25bd04ab6c79a29b5" checksum = "021e862c184ae977658b36c4500f7feac3221ca5da43e3f25bd04ab6c79a29b5"
dependencies = [ dependencies = [
"axum-core", "axum-core",
"base64",
"bytes", "bytes",
"form_urlencoded", "form_urlencoded",
"futures-util", "futures-util",
@@ -82,8 +83,10 @@ dependencies = [
"serde_json", "serde_json",
"serde_path_to_error", "serde_path_to_error",
"serde_urlencoded", "serde_urlencoded",
"sha1",
"sync_wrapper", "sync_wrapper",
"tokio", "tokio",
"tokio-tungstenite",
"tower", "tower",
"tower-layer", "tower-layer",
"tower-service", "tower-service",
@@ -286,6 +289,12 @@ dependencies = [
"parking_lot_core", "parking_lot_core",
] ]
[[package]]
name = "data-encoding"
version = "2.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2a2330da5de22e8a3cb63252ce2abb30116bf5265e89c0e01bc17015ce30a476"
[[package]] [[package]]
name = "deranged" name = "deranged"
version = "0.5.3" version = "0.5.3"
@@ -366,6 +375,23 @@ version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e"
[[package]]
name = "futures-macro"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "futures-sink"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e575fab7d1e0dcb8d0c7bcf9a63ee213816ab51902e6d244a95819acacf1d4f7"
[[package]] [[package]]
name = "futures-task" name = "futures-task"
version = "0.3.31" version = "0.3.31"
@@ -379,9 +405,12 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81"
dependencies = [ dependencies = [
"futures-core", "futures-core",
"futures-macro",
"futures-sink",
"futures-task", "futures-task",
"pin-project-lite", "pin-project-lite",
"pin-utils", "pin-utils",
"slab",
] ]
[[package]] [[package]]
@@ -460,6 +489,7 @@ dependencies = [
"chrono", "chrono",
"dashmap", "dashmap",
"dotenvy", "dotenvy",
"futures-util",
"jsonwebtoken", "jsonwebtoken",
"r2d2", "r2d2",
"r2d2_sqlite", "r2d2_sqlite",
@@ -1050,6 +1080,17 @@ dependencies = [
"serde", "serde",
] ]
[[package]]
name = "sha1"
version = "0.10.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba"
dependencies = [
"cfg-if",
"cpufeatures",
"digest",
]
[[package]] [[package]]
name = "shlex" name = "shlex"
version = "1.3.0" version = "1.3.0"
@@ -1204,6 +1245,18 @@ dependencies = [
"syn", "syn",
] ]
[[package]]
name = "tokio-tungstenite"
version = "0.26.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7a9daff607c6d2bf6c16fd681ccb7eecc83e4e2cdc1ca067ffaadfca5de7f084"
dependencies = [
"futures-util",
"log",
"tokio",
"tungstenite",
]
[[package]] [[package]]
name = "tower" name = "tower"
version = "0.5.2" version = "0.5.2"
@@ -1252,6 +1305,23 @@ dependencies = [
"once_cell", "once_cell",
] ]
[[package]]
name = "tungstenite"
version = "0.26.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4793cb5e56680ecbb1d843515b23b6de9a75eb04b66643e256a396d43be33c13"
dependencies = [
"bytes",
"data-encoding",
"http",
"httparse",
"log",
"rand",
"sha1",
"thiserror",
"utf-8",
]
[[package]] [[package]]
name = "typenum" name = "typenum"
version = "1.18.0" version = "1.18.0"
@@ -1270,6 +1340,12 @@ version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1"
[[package]]
name = "utf-8"
version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9"
[[package]] [[package]]
name = "uuid" name = "uuid"
version = "1.18.1" version = "1.18.1"

View File

@@ -5,7 +5,7 @@ edition = "2024"
publish = false publish = false
[dependencies] [dependencies]
axum = "0.8.4" axum = {version = "0.8.4", features = ["ws"]}
tokio = { version = "1", features = ["full"] } tokio = { version = "1", features = ["full"] }
serde = { version = "1.0", features = ["derive"] } serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0" serde_json = "1.0"
@@ -23,6 +23,7 @@ chrono = "0.4.42"
anyhow = "1.0.100" anyhow = "1.0.100"
argon2 = {version = "0.5.3"} argon2 = {version = "0.5.3"}
rand_core = {version = "0.6.4", features = ["getrandom"]} rand_core = {version = "0.6.4", features = ["getrandom"]}
futures-util = {version = "0.3.31"}

View File

@@ -1,7 +1,9 @@
use axum::serve; use axum::serve;
use axum::Extension; use axum::Extension;
use axum::extract::{ws::{Message, WebSocket, WebSocketUpgrade}, State};
use jsonwebtoken::{DecodingKey, EncodingKey}; use jsonwebtoken::{DecodingKey, EncodingKey};
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tokio::sync::mpsc;
mod utils; mod utils;
mod routes; mod routes;
@@ -9,6 +11,8 @@ mod rooms;
mod chat; mod chat;
use r2d2::{Pool}; use r2d2::{Pool};
use r2d2_sqlite::SqliteConnectionManager; use r2d2_sqlite::SqliteConnectionManager;
use dashmap::DashMap;
use std::sync::Arc;
use crate::utils::db_pool::{HotelPool,AppState}; use crate::utils::db_pool::{HotelPool,AppState};
use routes::create_router; use routes::create_router;
@@ -16,6 +20,7 @@ use crate::utils::auth::JwtKeys;
#[tokio::main] #[tokio::main]
async fn main() -> std::io::Result<()> { async fn main() -> std::io::Result<()> {
@@ -26,9 +31,17 @@ async fn main() -> std::io::Result<()> {
.build(logs_manager) .build(logs_manager)
.expect("Failed to build logs pool"); .expect("Failed to build logs pool");
type UserMap = DashMap<i32, mpsc::UnboundedSender<Message>>;
/// hotel_id → users
type HotelMap = DashMap<i32, Arc<UserMap>>;
/// global map of all hotels
type WsMap = Arc<HotelMap>;
let state = AppState { let state = AppState {
hotel_pools, hotel_pools,
logs_pool, logs_pool,
ws_map: Arc::new(DashMap::new()),
//jwt_secret: "your_jwt_secret_key s".to_string(), // better: load from env var //jwt_secret: "your_jwt_secret_key s".to_string(), // better: load from env var
}; };
@@ -38,6 +51,8 @@ async fn main() -> std::io::Result<()> {
decoding: DecodingKey::from_secret(jwt_secret.as_ref()), decoding: DecodingKey::from_secret(jwt_secret.as_ref()),
}; };
let app = create_router(state) let app = create_router(state)
.layer(Extension(jwt_keys)); .layer(Extension(jwt_keys));

View File

@@ -2,13 +2,25 @@ use std::sync::Arc;
use dashmap::DashMap; use dashmap::DashMap;
use r2d2::{Pool}; use r2d2::{Pool};
use r2d2_sqlite::SqliteConnectionManager; use r2d2_sqlite::SqliteConnectionManager;
use tokio::sync::mpsc;
use axum::extract::{ws::{Message, WebSocket, WebSocketUpgrade}, State};
type HotelId = i32; // or i32 if you want numeric ids type HotelId = i32; // or i32 if you want numeric ids
/// Type alias: user_id → sender to that user
type UserMap = DashMap<i32, mpsc::UnboundedSender<Message>>;
/// hotel_id → users
type HotelMap = DashMap<i32, Arc<UserMap>>;
/// global map of all hotels
type WsMap = Arc<HotelMap>;
/// Type alias: user_id → sender to that user
#[derive(Clone)] #[derive(Clone)]
pub struct AppState { pub struct AppState {
pub hotel_pools: HotelPool, pub hotel_pools: HotelPool,
pub logs_pool: Pool<SqliteConnectionManager>, pub logs_pool: Pool<SqliteConnectionManager>,
pub ws_map: WsMap,
} }

View File

@@ -1,3 +1,4 @@
pub mod db_pool; pub mod db_pool;
pub mod auth; pub mod auth;
pub mod routes; pub mod routes;
pub mod websocket;

View File

@@ -5,7 +5,7 @@ use axum::{
use crate::utils::auth::*; use crate::utils::auth::*;
use crate::utils::db_pool::{HotelPool, AppState, }; use crate::utils::db_pool::{HotelPool, AppState, };
use crate::utils::websocket::ws_handler;
// ROOTS // ROOTS
pub fn utils_routes() -> Router<AppState> { pub fn utils_routes() -> Router<AppState> {
@@ -13,6 +13,8 @@ pub fn utils_routes() -> Router<AppState> {
Router::new() Router::new()
.route("/login", put(clean_auth_loging)) .route("/login", put(clean_auth_loging))
.route("/register", put(register_user)) .route("/register", put(register_user))
.route("/ws/{hotel_id}/{user_id}", get(ws_handler))
.route("/tokentest", put(token_tester)) .route("/tokentest", put(token_tester))
//.with_state(state) //.with_state(state)
} }

108
src/utils/websocket.rs Normal file
View File

@@ -0,0 +1,108 @@
use dashmap::DashMap;
use std::sync::Arc;
use axum::extract::{ws::{Message, WebSocket, WebSocketUpgrade}, State};
use tokio::sync::mpsc;
use axum::extract::Path;
use axum::response::IntoResponse;
//use futures_util::stream::stream::StreamExt;
use futures_util::{StreamExt, SinkExt};
use crate::utils::db_pool::{HotelPool,AppState};
/// Type alias: user_id → sender to that user
type UserMap = DashMap<i32, mpsc::UnboundedSender<Message>>;
/// hotel_id → users
type HotelMap = DashMap<i32, Arc<UserMap>>;
/// global map of all hotels
type WsMap = Arc<HotelMap>;
/// Type alias: user_id → sender to that user
async fn handle_socket(
mut socket: WebSocket,
state: AppState,
hotel_id: i32,
user_id: i32,
) {
// channel for sending messages TO this client
let (tx, mut rx) = mpsc::unbounded_channel::<Message>();
// insert into hotel → user map
let user_map = state
.ws_map
.entry(hotel_id)
.or_insert_with(|| Arc::new(UserMap::new()))
.clone();
user_map.insert(user_id, tx);
// ✅ print after upgrading
print_ws_state(&state);
// split socket into sender/receiver
let (mut sender, mut receiver) = socket.split();
// task for sending messages from server to client
let mut rx_task = tokio::spawn(async move {
while let Some(msg) = rx.recv().await {
if sender.send(msg).await.is_err() {
break;
}
}
});
// task for receiving messages from client
let state_clone = state.clone();
let mut recv_task = tokio::spawn(async move {
while let Some(Ok(msg)) = receiver.next().await {
match msg {
Message::Text(text) => {
println!("Hotel {hotel_id}, User {user_id} said: {text}");
// echo back just as an example
if let Some(hotel_entry) = state_clone.ws_map.get(&hotel_id) {
if let Some(sender) = hotel_entry.get(&user_id) {
let _ = sender.send(Message::Text(format!("echo: {text}").into()));
}
}
}
Message::Close(_) => break,
_ => {}
}
}
});
// wait for either side to finish
tokio::select! {
_ = (&mut rx_task) => recv_task.abort(),
_ = (&mut recv_task) => rx_task.abort(),
}
// cleanup
user_map.remove(&user_id);
if user_map.is_empty() {
state.ws_map.remove(&hotel_id);
}
}
pub async fn ws_handler(
ws: WebSocketUpgrade,
State(state): State<AppState>,
Path((hotel_id, user_id)): Path<(i32, i32)>,
) -> impl IntoResponse {
ws.on_upgrade(move |socket| handle_socket(socket, state, hotel_id, user_id))
}
fn print_ws_state(state: &AppState) {
println!("--- Current WebSocket state ---");
for hotel_entry in state.ws_map.iter() {
let hotel_id = *hotel_entry.key();
let user_map = hotel_entry.value();
let users: Vec<_> = user_map.iter().map(|u| *u.key()).collect();
println!("Hotel {hotel_id}: users {:?}", users);
}
println!("--------------------------------");
}