130 lines
3.8 KiB
Rust
130 lines
3.8 KiB
Rust
use axum::extract::Path;
|
|
use axum::response::IntoResponse;
|
|
use axum::{
|
|
Extension,
|
|
extract::{
|
|
State,
|
|
ws::{Message, WebSocket, WebSocketUpgrade},
|
|
},
|
|
};
|
|
use dashmap::DashMap;
|
|
use reqwest::StatusCode;
|
|
use std::sync::Arc;
|
|
use tokio::sync::mpsc;
|
|
//use futures_util::stream::stream::StreamExt;
|
|
use futures_util::{SinkExt, StreamExt};
|
|
|
|
use crate::utils::{
|
|
auth::{AuthClaims, JwtKeys, auth_claims_from_token},
|
|
db_pool::{AppState, HotelPool},
|
|
};
|
|
|
|
/// Type alias: user_id → sender to that user
|
|
pub type UserMap = DashMap<i32, mpsc::UnboundedSender<Message>>;
|
|
/// hotel_id → users
|
|
pub type HotelMap = DashMap<i32, Arc<UserMap>>;
|
|
/// global map of all hotels
|
|
pub type WsMap = Arc<HotelMap>;
|
|
/// Type alias: user_id → sender to that user
|
|
|
|
async fn handle_socket(mut socket: WebSocket, state: AppState, hotel_id: i32, user_id: i32) {
|
|
// channel for sending messages TO this client
|
|
let (tx, mut rx) = mpsc::unbounded_channel::<Message>();
|
|
|
|
// insert into hotel → user map
|
|
let user_map = state
|
|
.ws_map
|
|
.entry(hotel_id)
|
|
.or_insert_with(|| Arc::new(UserMap::new()))
|
|
.clone();
|
|
|
|
user_map.insert(user_id, tx);
|
|
|
|
// ✅ print after upgrading
|
|
print_ws_state(&state);
|
|
|
|
// split socket into sender/receiver
|
|
let (mut sender, mut receiver) = socket.split();
|
|
|
|
// task for sending messages from server to client
|
|
let mut rx_task = tokio::spawn(async move {
|
|
while let Some(msg) = rx.recv().await {
|
|
if sender.send(msg).await.is_err() {
|
|
break;
|
|
}
|
|
}
|
|
});
|
|
|
|
// task for receiving messages from client
|
|
let state_clone = state.clone();
|
|
let mut recv_task = tokio::spawn(async move {
|
|
while let Some(Ok(msg)) = receiver.next().await {
|
|
match msg {
|
|
Message::Text(text) => {
|
|
println!("Hotel {hotel_id}, User {user_id} said: {text}");
|
|
// echo back just as an example
|
|
if let Some(hotel_entry) = state_clone.ws_map.get(&hotel_id) {
|
|
if let Some(sender) = hotel_entry.get(&user_id) {
|
|
let _ = sender.send(Message::Text(format!("echo: {text}").into()));
|
|
}
|
|
}
|
|
}
|
|
Message::Close(_) => break,
|
|
_ => {}
|
|
}
|
|
}
|
|
});
|
|
|
|
// wait for either side to finish
|
|
tokio::select! {
|
|
_ = (&mut rx_task) => recv_task.abort(),
|
|
_ = (&mut recv_task) => rx_task.abort(),
|
|
}
|
|
|
|
// cleanup
|
|
user_map.remove(&user_id);
|
|
if user_map.is_empty() {
|
|
state.ws_map.remove(&hotel_id);
|
|
}
|
|
}
|
|
|
|
pub async fn ws_handler(
|
|
//AuthClaims {user_id, hotel_id}: AuthClaims,
|
|
ws: WebSocketUpgrade,
|
|
Extension(keys): Extension<JwtKeys>,
|
|
State(state): State<AppState>,
|
|
Path((req_token)): Path<(String)>,
|
|
) -> impl IntoResponse {
|
|
let token = req_token;
|
|
|
|
let claims = match auth_claims_from_token(&token, &keys) {
|
|
Err(_) => {
|
|
print!("error during auth claims processing");
|
|
return StatusCode::UNAUTHORIZED.into_response();
|
|
}
|
|
Ok(c) => c,
|
|
};
|
|
|
|
print!("{token}, web socket tried to connect",);
|
|
|
|
/*
|
|
let claims = match auth_claims_from_token(&token, &keys) {
|
|
Ok(c) => c,
|
|
Err(_) => return StatusCode::UNAUTHORIZED.into_response(),
|
|
};
|
|
*/
|
|
|
|
ws.on_upgrade(move |socket| handle_socket(socket, state, claims.hotel_id, claims.user_id))
|
|
}
|
|
|
|
fn print_ws_state(state: &AppState) {
|
|
println!("--- Current WebSocket state ---");
|
|
for hotel_entry in state.ws_map.iter() {
|
|
let hotel_id = *hotel_entry.key();
|
|
let user_map = hotel_entry.value();
|
|
let users: Vec<_> = user_map.iter().map(|u| *u.key()).collect();
|
|
println!("Hotel {hotel_id}: users {:?}", users);
|
|
}
|
|
println!("--------------------------------");
|
|
}
|