Files
hotel_api/src/utils/websocket.rs
Romain Mallard e5a1d36654
Some checks failed
Deploy API / build-and-deploy (push) Failing after 5s
rust fmt and some cleaning
2026-03-11 14:22:46 +01:00

130 lines
3.8 KiB
Rust

use axum::extract::Path;
use axum::response::IntoResponse;
use axum::{
Extension,
extract::{
State,
ws::{Message, WebSocket, WebSocketUpgrade},
},
};
use dashmap::DashMap;
use reqwest::StatusCode;
use std::sync::Arc;
use tokio::sync::mpsc;
//use futures_util::stream::stream::StreamExt;
use futures_util::{SinkExt, StreamExt};
use crate::utils::{
auth::{AuthClaims, JwtKeys, auth_claims_from_token},
db_pool::{AppState, HotelPool},
};
/// Users of one hotel: `user_id` → sending half of the unbounded channel whose
/// messages are forwarded to that user's WebSocket.
pub type UserMap = DashMap<i32, mpsc::UnboundedSender<Message>>;
/// All hotels: `hotel_id` → the shared [`UserMap`] of that hotel.
pub type HotelMap = DashMap<i32, Arc<UserMap>>;
/// Cheaply clonable, shared handle to the map of all hotels.
pub type WsMap = Arc<HotelMap>;
/// Type alias: user_id → sender to that user
async fn handle_socket(mut socket: WebSocket, state: AppState, hotel_id: i32, user_id: i32) {
// channel for sending messages TO this client
let (tx, mut rx) = mpsc::unbounded_channel::<Message>();
// insert into hotel → user map
let user_map = state
.ws_map
.entry(hotel_id)
.or_insert_with(|| Arc::new(UserMap::new()))
.clone();
user_map.insert(user_id, tx);
// ✅ print after upgrading
print_ws_state(&state);
// split socket into sender/receiver
let (mut sender, mut receiver) = socket.split();
// task for sending messages from server to client
let mut rx_task = tokio::spawn(async move {
while let Some(msg) = rx.recv().await {
if sender.send(msg).await.is_err() {
break;
}
}
});
// task for receiving messages from client
let state_clone = state.clone();
let mut recv_task = tokio::spawn(async move {
while let Some(Ok(msg)) = receiver.next().await {
match msg {
Message::Text(text) => {
println!("Hotel {hotel_id}, User {user_id} said: {text}");
// echo back just as an example
if let Some(hotel_entry) = state_clone.ws_map.get(&hotel_id) {
if let Some(sender) = hotel_entry.get(&user_id) {
let _ = sender.send(Message::Text(format!("echo: {text}").into()));
}
}
}
Message::Close(_) => break,
_ => {}
}
}
});
// wait for either side to finish
tokio::select! {
_ = (&mut rx_task) => recv_task.abort(),
_ = (&mut recv_task) => rx_task.abort(),
}
// cleanup
user_map.remove(&user_id);
if user_map.is_empty() {
state.ws_map.remove(&hotel_id);
}
}
pub async fn ws_handler(
//AuthClaims {user_id, hotel_id}: AuthClaims,
ws: WebSocketUpgrade,
Extension(keys): Extension<JwtKeys>,
State(state): State<AppState>,
Path((req_token)): Path<(String)>,
) -> impl IntoResponse {
let token = req_token;
let claims = match auth_claims_from_token(&token, &keys) {
Err(_) => {
print!("error during auth claims processing");
return StatusCode::UNAUTHORIZED.into_response();
}
Ok(c) => c,
};
print!("{token}, web socket tried to connect",);
/*
let claims = match auth_claims_from_token(&token, &keys) {
Ok(c) => c,
Err(_) => return StatusCode::UNAUTHORIZED.into_response(),
};
*/
ws.on_upgrade(move |socket| handle_socket(socket, state, claims.hotel_id, claims.user_id))
}
/// Debug helper: prints every hotel in `state.ws_map` together with the ids of the
/// users currently registered for it, framed by a header and a footer line.
fn print_ws_state(state: &AppState) {
    println!("--- Current WebSocket state ---");
    state.ws_map.iter().for_each(|entry| {
        let (hotel_id, members) = entry.pair();
        // Gather the registered user ids of this hotel in iteration order.
        let mut users = Vec::new();
        for member in members.iter() {
            users.push(*member.key());
        }
        println!("Hotel {hotel_id}: users {:?}", users);
    });
    println!("--------------------------------");
}