refactoring token -> claim processing

This commit is contained in:
2026-01-03 17:15:34 +01:00
parent c0d70077d7
commit 170fedbcbd
12 changed files with 95 additions and 60 deletions

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -322,7 +322,7 @@ pub async fn send_message(
if let Some(hotel_users) = state.ws_map.get(&hotel_id) { if let Some(hotel_users) = state.ws_map.get(&hotel_id) {
let update_msg = serde_json::json!({ let update_msg = serde_json::json!({
"event-type": "chat-message", "event_type": "chat_message",
"conv_id": payload.conv_id, "conv_id": payload.conv_id,
"sender": user_id, "sender": user_id,
"content": payload.message, "content": payload.message,

View File

@@ -77,7 +77,7 @@ pub async fn update_inventory_item(
(StatusCode::OK, format!("updated item history")) (StatusCode::OK, format!("updated item history"))
} }
Ok(_) => (StatusCode::NOT_FOUND, "No room found".to_string()), Ok(_) => (StatusCode::NOT_FOUND, "No item found, err : {_}".to_string()),
Err(err) => (StatusCode::INTERNAL_SERVER_ERROR, format!("Error from DB: {err}")), Err(err) => (StatusCode::INTERNAL_SERVER_ERROR, format!("Error from DB: {err}")),
} }
@@ -104,7 +104,7 @@ pub async fn get_inventory_item(
Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, "Pool error".to_string()), Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, "Pool error".to_string()),
}; };
let mut stmt = match conn.prepare("SELECT id, amount, item_name, user_id FROM inventory") { let mut stmt = match conn.prepare("SELECT id, amount, item_name, user_id, updated_at FROM inventory") {
Ok(s) => s, Ok(s) => s,
Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, "Statement error".to_string()), Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, "Statement error".to_string()),
}; };
@@ -120,7 +120,7 @@ pub async fn get_inventory_item(
let item = InventoryItems { let item = InventoryItems {
id: row.get("id").unwrap_or_default(), id: row.get("id").unwrap_or_default(),
amount: row.get("amount").unwrap_or_default(), amount: row.get("amount").unwrap_or_default(),
name: row.get("name").unwrap_or_default(), name: row.get("item_name").unwrap_or_default(),
user_id: row.get("user_id").unwrap_or_default(), user_id: row.get("user_id").unwrap_or_default(),
updated_at: row.get("updated_at").unwrap_or_default(), updated_at: row.get("updated_at").unwrap_or_default(),
}; };

View File

@@ -50,22 +50,22 @@ async fn main() -> std::io::Result<()> {
dotenv().ok(); dotenv().ok();
std::panic::set_hook(Box::new(|info| { std::panic::set_hook(Box::new(|info| {
let msg = format!("Rust panic: {}", info); let msg = format!("Rust panic: {}", info);
// Use blocking client so the process can't exit before sending // Use blocking client so the process can't exit before sending
let payload = serde_json::json!({ let payload = serde_json::json!({
"content": msg "content": msg
}); });
let client = reqwest::blocking::Client::new(); let client = reqwest::blocking::Client::new();
let _ = client let _ = client
.post("https://discord.com/api/webhooks/1440912618205347891/Ekg89krDoPm41kA27LA3gXgNWmMWvCCtziYIUsjqaY22Jnw4a6IWhZOht0in5JjnPX-W") .post("https://discord.com/api/webhooks/1440912618205347891/Ekg89krDoPm41kA27LA3gXgNWmMWvCCtziYIUsjqaY22Jnw4a6IWhZOht0in5JjnPX-W")
.json(&payload) .json(&payload)
.send(); .send();
})); }));
//panic!("crash-test"); //panic!("crash-test");
let hotel_pools = HotelPool::new(); let hotel_pools = HotelPool::new();
let logs_manager = SqliteConnectionManager::file("db/auth_copy_2.sqlite"); let logs_manager = SqliteConnectionManager::file("db/auth_copy_2.sqlite");
@@ -92,15 +92,15 @@ std::panic::set_hook(Box::new(|info| {
decoding: DecodingKey::from_secret(jwt_secret.as_ref()), decoding: DecodingKey::from_secret(jwt_secret.as_ref()),
}; };
let allowed_origins = vec![ let allowed_origins = vec![
"http://82.66.253.209", "http://82.66.253.209",
"http://localhost:5173", "http://localhost:5173",
]; ];
let cors = CorsLayer::very_permissive() let cors = CorsLayer::very_permissive()
.allow_credentials(true) .allow_credentials(true)
.allow_methods([Method::GET, Method::POST, Method::PUT, Method::OPTIONS]) .allow_methods([Method::GET, Method::POST, Method::PUT, Method::OPTIONS])
.allow_headers([CONTENT_TYPE, AUTHORIZATION]); .allow_headers([CONTENT_TYPE, AUTHORIZATION]);
let app = create_router(state) let app = create_router(state)
.layer(Extension(jwt_keys)) .layer(Extension(jwt_keys))
.layer(cors); .layer(cors);
@@ -110,6 +110,6 @@ let cors = CorsLayer::very_permissive()
Ok(()) Ok(())
} }
async fn handler() -> &'static str { async fn handler() -> &'static str {
"Hiii from localhost" "Hiii from localhost"
} }

View File

@@ -83,7 +83,7 @@ pub async fn clean_db_update(
} }
if let Some(hotel_users) = state.ws_map.get(&hotel_id) { if let Some(hotel_users) = state.ws_map.get(&hotel_id) {
let update_msg = json!({ let update_msg = json!({
"event-type": "room-update", "event_type": "room_update",
"room_id": room_id, "room_id": room_id,
"status": payload.status, "status": payload.status,
"updated_by": user_id, "updated_by": user_id,

View File

@@ -47,6 +47,8 @@ pub async fn token_tester(
) )
} }
pub struct AuthUser(pub Claims); //?? pub struct AuthUser(pub Claims); //??
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@@ -56,44 +58,51 @@ pub struct AuthClaims {
//pub username: String, //pub username: String,
} }
pub fn auth_claims_from_token(
token: &str,
keys: &JwtKeys,
) -> Result<AuthClaims, (StatusCode, String)> {
let token_data = decode::<Claims>(
token,
&keys.decoding,
&Validation::new(Algorithm::HS256),
).map_err(|_| (StatusCode::UNAUTHORIZED, "Invalid token".into()))?;
Ok(AuthClaims {
user_id: token_data.claims.id,
hotel_id: token_data.claims.hotel_id,
})
}
impl<S> FromRequestParts<S> for AuthClaims impl<S> FromRequestParts<S> for AuthClaims
where where
S: Send + Sync + 'static, S: Send + Sync + 'static,
AppState: Clone + Send + Sync + 'static, AppState: FromRef<S> AppState: FromRef<S>,
{ {
type Rejection = (StatusCode, String); type Rejection = (StatusCode, String);
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> { async fn from_request_parts(
// We assume your state has a `jwt_secret` field parts: &mut Parts,
state: &S,
) -> Result<Self, Self::Rejection> {
let Extension(keys): Extension<JwtKeys> = let Extension(keys): Extension<JwtKeys> =
Extension::from_request_parts(parts, state).await.map_err(|_| (StatusCode::UNAUTHORIZED, "Missing keys".to_string()))?; Extension::from_request_parts(parts, state)
.await
.map_err(|_| (StatusCode::UNAUTHORIZED, "Missing keys".into()))?;
// 1⃣ Extract the token from the Authorization header
let auth_header = parts let auth_header = parts
.headers .headers
.get("Authorization") .get(axum::http::header::AUTHORIZATION)
.ok_or((StatusCode::UNAUTHORIZED, "Missing Authorization header".to_string()))? .ok_or((StatusCode::UNAUTHORIZED, "Missing Authorization header".into()))?
.to_str() .to_str()
.map_err(|_| (StatusCode::BAD_REQUEST, "Invalid Authorization header".to_string()))?; .map_err(|_| (StatusCode::BAD_REQUEST, "Invalid Authorization header".into()))?;
// Bearer token?
let token = auth_header let token = auth_header
.strip_prefix("Bearer ") .strip_prefix("Bearer ")
.ok_or((StatusCode::BAD_REQUEST, "Expected Bearer token".to_string()))?; .ok_or((StatusCode::BAD_REQUEST, "Expected Bearer token".into()))?;
// 2⃣ Decode the token
let token_data = decode::<Claims>(
token,
&keys.decoding,
&Validation::new(Algorithm::HS256),
).map_err(|_| (StatusCode::UNAUTHORIZED, "Invalid token".to_string()))?;
Ok(AuthClaims {
user_id: token_data.claims.id,
hotel_id: token_data.claims.hotel_id,
//username: token_data.claims.username,
})
auth_claims_from_token(token, &keys)
} }
} }

View File

@@ -16,7 +16,7 @@ pub fn utils_routes() -> Router<AppState> {
Router::new() Router::new()
.route("/login", put(clean_auth_loging)) .route("/login", put(clean_auth_loging))
.route("/register", put(register_user)) .route("/register", put(register_user))
.route("/ws/", get(ws_handler))
.route("/tokentest", put(token_tester)) .route("/tokentest", put(token_tester))
.route("/force_update_password", put(force_update_password)) .route("/force_update_password", put(force_update_password))
.route("/update_password", put(update_password)) .route("/update_password", put(update_password))
@@ -27,7 +27,7 @@ pub fn utils_routes() -> Router<AppState> {
.route("/logout_single_device", post(logout_from_single_device)) .route("/logout_single_device", post(logout_from_single_device))
.route("/logout_all_devices", post(logout_from_all_devices)) .route("/logout_all_devices", post(logout_from_all_devices))
.route("/ws/{req_token}", get(ws_handler))
//.with_state(state) //.with_state(state)
} }

View File

@@ -1,13 +1,14 @@
use dashmap::DashMap; use dashmap::DashMap;
use reqwest::StatusCode;
use std::sync::Arc; use std::sync::Arc;
use axum::extract::{ws::{Message, WebSocket, WebSocketUpgrade}, State}; use axum::{Extension, extract::{State, ws::{Message, WebSocket, WebSocketUpgrade}}};
use tokio::sync::mpsc; use tokio::sync::mpsc;
use axum::extract::Path; use axum::extract::Path;
use axum::response::IntoResponse; use axum::response::IntoResponse;
//use futures_util::stream::stream::StreamExt; //use futures_util::stream::stream::StreamExt;
use futures_util::{StreamExt, SinkExt}; use futures_util::{StreamExt, SinkExt};
use crate::utils::{auth::AuthClaims, db_pool::{AppState, HotelPool}}; use crate::utils::{auth::{AuthClaims, JwtKeys, auth_claims_from_token}, db_pool::{AppState, HotelPool}};
@@ -89,12 +90,37 @@ async fn handle_socket(
} }
pub async fn ws_handler( pub async fn ws_handler(
AuthClaims {user_id, hotel_id}: AuthClaims, //AuthClaims {user_id, hotel_id}: AuthClaims,
ws: WebSocketUpgrade, ws: WebSocketUpgrade,
Extension(keys): Extension<JwtKeys>,
State(state): State<AppState>, State(state): State<AppState>,
//Path((hotel_id, user_id)): Path<(i32, i32)>, Path((req_token)): Path<(String)>,
) -> impl IntoResponse { ) -> impl IntoResponse {
ws.on_upgrade(move |socket| handle_socket(socket, state, hotel_id, user_id))
let token = req_token;
let claims = match auth_claims_from_token(&token, &keys) {
Err(_) => {
print!("error during auth claims processing");
return StatusCode::UNAUTHORIZED.into_response();
}
Ok(c) => c
};
print!("{token}, web socket tried to connect", );
/*
let claims = match auth_claims_from_token(&token, &keys) {
Ok(c) => c,
Err(_) => return StatusCode::UNAUTHORIZED.into_response(),
};
*/
ws.on_upgrade(move |socket| handle_socket(socket, state, claims.hotel_id, claims.user_id))
} }
fn print_ws_state(state: &AppState) { fn print_ws_state(state: &AppState) {