multi-hotel-refactor #3

Merged
Rominou merged 27 commits from multi-hotel-refactor into master 2026-03-11 13:32:43 +00:00
12 changed files with 95 additions and 60 deletions
Showing only changes of commit 170fedbcbd - Show all commits

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -322,7 +322,7 @@ pub async fn send_message(
if let Some(hotel_users) = state.ws_map.get(&hotel_id) { if let Some(hotel_users) = state.ws_map.get(&hotel_id) {
let update_msg = serde_json::json!({ let update_msg = serde_json::json!({
"event-type": "chat-message", "event_type": "chat_message",
"conv_id": payload.conv_id, "conv_id": payload.conv_id,
"sender": user_id, "sender": user_id,
"content": payload.message, "content": payload.message,

View File

@@ -77,7 +77,7 @@ pub async fn update_inventory_item(
(StatusCode::OK, format!("updated item history")) (StatusCode::OK, format!("updated item history"))
} }
Ok(_) => (StatusCode::NOT_FOUND, "No room found".to_string()), Ok(_) => (StatusCode::NOT_FOUND, "No item found, err : {_}".to_string()),
Err(err) => (StatusCode::INTERNAL_SERVER_ERROR, format!("Error from DB: {err}")), Err(err) => (StatusCode::INTERNAL_SERVER_ERROR, format!("Error from DB: {err}")),
} }
@@ -104,7 +104,7 @@ pub async fn get_inventory_item(
Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, "Pool error".to_string()), Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, "Pool error".to_string()),
}; };
let mut stmt = match conn.prepare("SELECT id, amount, item_name, user_id FROM inventory") { let mut stmt = match conn.prepare("SELECT id, amount, item_name, user_id, updated_at FROM inventory") {
Ok(s) => s, Ok(s) => s,
Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, "Statement error".to_string()), Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, "Statement error".to_string()),
}; };
@@ -120,7 +120,7 @@ pub async fn get_inventory_item(
let item = InventoryItems { let item = InventoryItems {
id: row.get("id").unwrap_or_default(), id: row.get("id").unwrap_or_default(),
amount: row.get("amount").unwrap_or_default(), amount: row.get("amount").unwrap_or_default(),
name: row.get("name").unwrap_or_default(), name: row.get("item_name").unwrap_or_default(),
user_id: row.get("user_id").unwrap_or_default(), user_id: row.get("user_id").unwrap_or_default(),
updated_at: row.get("updated_at").unwrap_or_default(), updated_at: row.get("updated_at").unwrap_or_default(),
}; };

View File

@@ -50,7 +50,7 @@ async fn main() -> std::io::Result<()> {
dotenv().ok(); dotenv().ok();
std::panic::set_hook(Box::new(|info| { std::panic::set_hook(Box::new(|info| {
let msg = format!("Rust panic: {}", info); let msg = format!("Rust panic: {}", info);
// Use blocking client so the process can't exit before sending // Use blocking client so the process can't exit before sending
@@ -63,9 +63,9 @@ std::panic::set_hook(Box::new(|info| {
.post("https://discord.com/api/webhooks/1440912618205347891/Ekg89krDoPm41kA27LA3gXgNWmMWvCCtziYIUsjqaY22Jnw4a6IWhZOht0in5JjnPX-W") .post("https://discord.com/api/webhooks/1440912618205347891/Ekg89krDoPm41kA27LA3gXgNWmMWvCCtziYIUsjqaY22Jnw4a6IWhZOht0in5JjnPX-W")
.json(&payload) .json(&payload)
.send(); .send();
})); }));
//panic!("crash-test"); //panic!("crash-test");
let hotel_pools = HotelPool::new(); let hotel_pools = HotelPool::new();
let logs_manager = SqliteConnectionManager::file("db/auth_copy_2.sqlite"); let logs_manager = SqliteConnectionManager::file("db/auth_copy_2.sqlite");
@@ -92,12 +92,12 @@ std::panic::set_hook(Box::new(|info| {
decoding: DecodingKey::from_secret(jwt_secret.as_ref()), decoding: DecodingKey::from_secret(jwt_secret.as_ref()),
}; };
let allowed_origins = vec![ let allowed_origins = vec![
"http://82.66.253.209", "http://82.66.253.209",
"http://localhost:5173", "http://localhost:5173",
]; ];
let cors = CorsLayer::very_permissive() let cors = CorsLayer::very_permissive()
.allow_credentials(true) .allow_credentials(true)
.allow_methods([Method::GET, Method::POST, Method::PUT, Method::OPTIONS]) .allow_methods([Method::GET, Method::POST, Method::PUT, Method::OPTIONS])
.allow_headers([CONTENT_TYPE, AUTHORIZATION]); .allow_headers([CONTENT_TYPE, AUTHORIZATION]);
@@ -110,6 +110,6 @@ let cors = CorsLayer::very_permissive()
Ok(()) Ok(())
} }
async fn handler() -> &'static str { async fn handler() -> &'static str {
"Hiii from localhost" "Hiii from localhost"
} }

View File

@@ -83,7 +83,7 @@ pub async fn clean_db_update(
} }
if let Some(hotel_users) = state.ws_map.get(&hotel_id) { if let Some(hotel_users) = state.ws_map.get(&hotel_id) {
let update_msg = json!({ let update_msg = json!({
"event-type": "room-update", "event_type": "room_update",
"room_id": room_id, "room_id": room_id,
"status": payload.status, "status": payload.status,
"updated_by": user_id, "updated_by": user_id,

View File

@@ -47,6 +47,8 @@ pub async fn token_tester(
) )
} }
pub struct AuthUser(pub Claims); //?? pub struct AuthUser(pub Claims); //??
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@@ -56,44 +58,51 @@ pub struct AuthClaims {
//pub username: String, //pub username: String,
} }
impl<S> FromRequestParts<S> for AuthClaims pub fn auth_claims_from_token(
where token: &str,
S: Send + Sync + 'static, keys: &JwtKeys,
AppState: Clone + Send + Sync + 'static, AppState: FromRef<S> ) -> Result<AuthClaims, (StatusCode, String)> {
{
type Rejection = (StatusCode, String);
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
// We assume your state has a `jwt_secret` field
let Extension(keys): Extension<JwtKeys> =
Extension::from_request_parts(parts, state).await.map_err(|_| (StatusCode::UNAUTHORIZED, "Missing keys".to_string()))?;
// 1⃣ Extract the token from the Authorization header
let auth_header = parts
.headers
.get("Authorization")
.ok_or((StatusCode::UNAUTHORIZED, "Missing Authorization header".to_string()))?
.to_str()
.map_err(|_| (StatusCode::BAD_REQUEST, "Invalid Authorization header".to_string()))?;
// Bearer token?
let token = auth_header
.strip_prefix("Bearer ")
.ok_or((StatusCode::BAD_REQUEST, "Expected Bearer token".to_string()))?;
// 2⃣ Decode the token
let token_data = decode::<Claims>( let token_data = decode::<Claims>(
token, token,
&keys.decoding, &keys.decoding,
&Validation::new(Algorithm::HS256), &Validation::new(Algorithm::HS256),
).map_err(|_| (StatusCode::UNAUTHORIZED, "Invalid token".to_string()))?; ).map_err(|_| (StatusCode::UNAUTHORIZED, "Invalid token".into()))?;
Ok(AuthClaims { Ok(AuthClaims {
user_id: token_data.claims.id, user_id: token_data.claims.id,
hotel_id: token_data.claims.hotel_id, hotel_id: token_data.claims.hotel_id,
//username: token_data.claims.username,
}) })
}
impl<S> FromRequestParts<S> for AuthClaims
where
S: Send + Sync + 'static,
AppState: FromRef<S>,
{
type Rejection = (StatusCode, String);
async fn from_request_parts(
parts: &mut Parts,
state: &S,
) -> Result<Self, Self::Rejection> {
let Extension(keys): Extension<JwtKeys> =
Extension::from_request_parts(parts, state)
.await
.map_err(|_| (StatusCode::UNAUTHORIZED, "Missing keys".into()))?;
let auth_header = parts
.headers
.get(axum::http::header::AUTHORIZATION)
.ok_or((StatusCode::UNAUTHORIZED, "Missing Authorization header".into()))?
.to_str()
.map_err(|_| (StatusCode::BAD_REQUEST, "Invalid Authorization header".into()))?;
let token = auth_header
.strip_prefix("Bearer ")
.ok_or((StatusCode::BAD_REQUEST, "Expected Bearer token".into()))?;
auth_claims_from_token(token, &keys)
} }
} }

View File

@@ -16,7 +16,7 @@ pub fn utils_routes() -> Router<AppState> {
Router::new() Router::new()
.route("/login", put(clean_auth_loging)) .route("/login", put(clean_auth_loging))
.route("/register", put(register_user)) .route("/register", put(register_user))
.route("/ws/", get(ws_handler))
.route("/tokentest", put(token_tester)) .route("/tokentest", put(token_tester))
.route("/force_update_password", put(force_update_password)) .route("/force_update_password", put(force_update_password))
.route("/update_password", put(update_password)) .route("/update_password", put(update_password))
@@ -27,7 +27,7 @@ pub fn utils_routes() -> Router<AppState> {
.route("/logout_single_device", post(logout_from_single_device)) .route("/logout_single_device", post(logout_from_single_device))
.route("/logout_all_devices", post(logout_from_all_devices)) .route("/logout_all_devices", post(logout_from_all_devices))
.route("/ws/{req_token}", get(ws_handler))
//.with_state(state) //.with_state(state)
} }

View File

@@ -1,13 +1,14 @@
use dashmap::DashMap; use dashmap::DashMap;
use reqwest::StatusCode;
use std::sync::Arc; use std::sync::Arc;
use axum::extract::{ws::{Message, WebSocket, WebSocketUpgrade}, State}; use axum::{Extension, extract::{State, ws::{Message, WebSocket, WebSocketUpgrade}}};
use tokio::sync::mpsc; use tokio::sync::mpsc;
use axum::extract::Path; use axum::extract::Path;
use axum::response::IntoResponse; use axum::response::IntoResponse;
//use futures_util::stream::stream::StreamExt; //use futures_util::stream::stream::StreamExt;
use futures_util::{StreamExt, SinkExt}; use futures_util::{StreamExt, SinkExt};
use crate::utils::{auth::AuthClaims, db_pool::{AppState, HotelPool}}; use crate::utils::{auth::{AuthClaims, JwtKeys, auth_claims_from_token}, db_pool::{AppState, HotelPool}};
@@ -89,12 +90,37 @@ async fn handle_socket(
} }
pub async fn ws_handler( pub async fn ws_handler(
AuthClaims {user_id, hotel_id}: AuthClaims, //AuthClaims {user_id, hotel_id}: AuthClaims,
ws: WebSocketUpgrade, ws: WebSocketUpgrade,
Extension(keys): Extension<JwtKeys>,
State(state): State<AppState>, State(state): State<AppState>,
//Path((hotel_id, user_id)): Path<(i32, i32)>, Path((req_token)): Path<(String)>,
) -> impl IntoResponse { ) -> impl IntoResponse {
ws.on_upgrade(move |socket| handle_socket(socket, state, hotel_id, user_id))
let token = req_token;
let claims = match auth_claims_from_token(&token, &keys) {
Err(_) => {
print!("error during auth claims processing");
return StatusCode::UNAUTHORIZED.into_response();
}
Ok(c) => c
};
print!("{token}, web socket tried to connect", );
/*
let claims = match auth_claims_from_token(&token, &keys) {
Ok(c) => c,
Err(_) => return StatusCode::UNAUTHORIZED.into_response(),
};
*/
ws.on_upgrade(move |socket| handle_socket(socket, state, claims.hotel_id, claims.user_id))
} }
fn print_ws_state(state: &AppState) { fn print_ws_state(state: &AppState) {