use axum::extract::Path; use axum::response::IntoResponse; use axum::{ Extension, extract::{ State, ws::{Message, WebSocket, WebSocketUpgrade}, }, }; use dashmap::DashMap; use reqwest::StatusCode; use std::sync::Arc; use tokio::sync::mpsc; //use futures_util::stream::stream::StreamExt; use futures_util::{SinkExt, StreamExt}; use crate::utils::{ auth::{AuthClaims, JwtKeys, auth_claims_from_token}, db_pool::{AppState, HotelPool}, }; /// Type alias: user_id → sender to that user pub type UserMap = DashMap>; /// hotel_id → users pub type HotelMap = DashMap>; /// global map of all hotels pub type WsMap = Arc; /// Type alias: user_id → sender to that user async fn handle_socket(mut socket: WebSocket, state: AppState, hotel_id: i32, user_id: i32) { // channel for sending messages TO this client let (tx, mut rx) = mpsc::unbounded_channel::(); // insert into hotel → user map let user_map = state .ws_map .entry(hotel_id) .or_insert_with(|| Arc::new(UserMap::new())) .clone(); user_map.insert(user_id, tx); // ✅ print after upgrading print_ws_state(&state); // split socket into sender/receiver let (mut sender, mut receiver) = socket.split(); // task for sending messages from server to client let mut rx_task = tokio::spawn(async move { while let Some(msg) = rx.recv().await { if sender.send(msg).await.is_err() { break; } } }); // task for receiving messages from client let state_clone = state.clone(); let mut recv_task = tokio::spawn(async move { while let Some(Ok(msg)) = receiver.next().await { match msg { Message::Text(text) => { println!("Hotel {hotel_id}, User {user_id} said: {text}"); // echo back just as an example if let Some(hotel_entry) = state_clone.ws_map.get(&hotel_id) { if let Some(sender) = hotel_entry.get(&user_id) { let _ = sender.send(Message::Text(format!("echo: {text}").into())); } } } Message::Close(_) => break, _ => {} } } }); // wait for either side to finish tokio::select! { _ = (&mut rx_task) => recv_task.abort(), _ = (&mut recv_task) => rx_task.abort(), } // cleanup user_map.remove(&user_id); if user_map.is_empty() { state.ws_map.remove(&hotel_id); } } pub async fn ws_handler( //AuthClaims {user_id, hotel_id}: AuthClaims, ws: WebSocketUpgrade, Extension(keys): Extension, State(state): State, Path((req_token)): Path<(String)>, ) -> impl IntoResponse { let token = req_token; let claims = match auth_claims_from_token(&token, &keys) { Err(_) => { print!("error during auth claims processing"); return StatusCode::UNAUTHORIZED.into_response(); } Ok(c) => c, }; print!("{token}, web socket tried to connect",); /* let claims = match auth_claims_from_token(&token, &keys) { Ok(c) => c, Err(_) => return StatusCode::UNAUTHORIZED.into_response(), }; */ ws.on_upgrade(move |socket| handle_socket(socket, state, claims.hotel_id, claims.user_id)) } fn print_ws_state(state: &AppState) { println!("--- Current WebSocket state ---"); for hotel_entry in state.ws_map.iter() { let hotel_id = *hotel_entry.key(); let user_map = hotel_entry.value(); let users: Vec<_> = user_map.iter().map(|u| *u.key()).collect(); println!("Hotel {hotel_id}: users {:?}", users); } println!("--------------------------------"); }