refactoring token -> claim processing

This commit is contained in:
2026-01-03 17:15:34 +01:00
parent c0d70077d7
commit 170fedbcbd
12 changed files with 95 additions and 60 deletions

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -322,7 +322,7 @@ pub async fn send_message(
if let Some(hotel_users) = state.ws_map.get(&hotel_id) { if let Some(hotel_users) = state.ws_map.get(&hotel_id) {
let update_msg = serde_json::json!({ let update_msg = serde_json::json!({
"event-type": "chat-message", "event_type": "chat_message",
"conv_id": payload.conv_id, "conv_id": payload.conv_id,
"sender": user_id, "sender": user_id,
"content": payload.message, "content": payload.message,

View File

@@ -77,7 +77,7 @@ pub async fn update_inventory_item(
(StatusCode::OK, format!("updated item history")) (StatusCode::OK, format!("updated item history"))
} }
Ok(_) => (StatusCode::NOT_FOUND, "No room found".to_string()), Ok(_) => (StatusCode::NOT_FOUND, "No item found, err : {_}".to_string()),
Err(err) => (StatusCode::INTERNAL_SERVER_ERROR, format!("Error from DB: {err}")), Err(err) => (StatusCode::INTERNAL_SERVER_ERROR, format!("Error from DB: {err}")),
} }
@@ -104,7 +104,7 @@ pub async fn get_inventory_item(
Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, "Pool error".to_string()), Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, "Pool error".to_string()),
}; };
let mut stmt = match conn.prepare("SELECT id, amount, item_name, user_id FROM inventory") { let mut stmt = match conn.prepare("SELECT id, amount, item_name, user_id, updated_at FROM inventory") {
Ok(s) => s, Ok(s) => s,
Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, "Statement error".to_string()), Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, "Statement error".to_string()),
}; };
@@ -120,7 +120,7 @@ pub async fn get_inventory_item(
let item = InventoryItems { let item = InventoryItems {
id: row.get("id").unwrap_or_default(), id: row.get("id").unwrap_or_default(),
amount: row.get("amount").unwrap_or_default(), amount: row.get("amount").unwrap_or_default(),
name: row.get("name").unwrap_or_default(), name: row.get("item_name").unwrap_or_default(),
user_id: row.get("user_id").unwrap_or_default(), user_id: row.get("user_id").unwrap_or_default(),
updated_at: row.get("updated_at").unwrap_or_default(), updated_at: row.get("updated_at").unwrap_or_default(),
}; };

View File

@@ -50,7 +50,7 @@ async fn main() -> std::io::Result<()> {
dotenv().ok(); dotenv().ok();
std::panic::set_hook(Box::new(|info| { std::panic::set_hook(Box::new(|info| {
let msg = format!("Rust panic: {}", info); let msg = format!("Rust panic: {}", info);
// Use blocking client so the process can't exit before sending // Use blocking client so the process can't exit before sending
@@ -63,9 +63,9 @@ std::panic::set_hook(Box::new(|info| {
.post("https://discord.com/api/webhooks/1440912618205347891/Ekg89krDoPm41kA27LA3gXgNWmMWvCCtziYIUsjqaY22Jnw4a6IWhZOht0in5JjnPX-W") .post("https://discord.com/api/webhooks/1440912618205347891/Ekg89krDoPm41kA27LA3gXgNWmMWvCCtziYIUsjqaY22Jnw4a6IWhZOht0in5JjnPX-W")
.json(&payload) .json(&payload)
.send(); .send();
})); }));
//panic!("crash-test"); //panic!("crash-test");
let hotel_pools = HotelPool::new(); let hotel_pools = HotelPool::new();
let logs_manager = SqliteConnectionManager::file("db/auth_copy_2.sqlite"); let logs_manager = SqliteConnectionManager::file("db/auth_copy_2.sqlite");
@@ -92,12 +92,12 @@ std::panic::set_hook(Box::new(|info| {
decoding: DecodingKey::from_secret(jwt_secret.as_ref()), decoding: DecodingKey::from_secret(jwt_secret.as_ref()),
}; };
let allowed_origins = vec![ let allowed_origins = vec![
"http://82.66.253.209", "http://82.66.253.209",
"http://localhost:5173", "http://localhost:5173",
]; ];
let cors = CorsLayer::very_permissive() let cors = CorsLayer::very_permissive()
.allow_credentials(true) .allow_credentials(true)
.allow_methods([Method::GET, Method::POST, Method::PUT, Method::OPTIONS]) .allow_methods([Method::GET, Method::POST, Method::PUT, Method::OPTIONS])
.allow_headers([CONTENT_TYPE, AUTHORIZATION]); .allow_headers([CONTENT_TYPE, AUTHORIZATION]);
@@ -110,6 +110,6 @@ let cors = CorsLayer::very_permissive()
Ok(()) Ok(())
} }
async fn handler() -> &'static str { async fn handler() -> &'static str {
"Hiii from localhost" "Hiii from localhost"
} }

View File

@@ -83,7 +83,7 @@ pub async fn clean_db_update(
} }
if let Some(hotel_users) = state.ws_map.get(&hotel_id) { if let Some(hotel_users) = state.ws_map.get(&hotel_id) {
let update_msg = json!({ let update_msg = json!({
"event-type": "room-update", "event_type": "room_update",
"room_id": room_id, "room_id": room_id,
"status": payload.status, "status": payload.status,
"updated_by": user_id, "updated_by": user_id,

View File

@@ -47,6 +47,8 @@ pub async fn token_tester(
) )
} }
pub struct AuthUser(pub Claims); //?? pub struct AuthUser(pub Claims); //??
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@@ -56,44 +58,51 @@ pub struct AuthClaims {
//pub username: String, //pub username: String,
} }
impl<S> FromRequestParts<S> for AuthClaims pub fn auth_claims_from_token(
where token: &str,
S: Send + Sync + 'static, keys: &JwtKeys,
AppState: Clone + Send + Sync + 'static, AppState: FromRef<S> ) -> Result<AuthClaims, (StatusCode, String)> {
{
type Rejection = (StatusCode, String);
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
// We assume your state has a `jwt_secret` field
let Extension(keys): Extension<JwtKeys> =
Extension::from_request_parts(parts, state).await.map_err(|_| (StatusCode::UNAUTHORIZED, "Missing keys".to_string()))?;
// 1⃣ Extract the token from the Authorization header
let auth_header = parts
.headers
.get("Authorization")
.ok_or((StatusCode::UNAUTHORIZED, "Missing Authorization header".to_string()))?
.to_str()
.map_err(|_| (StatusCode::BAD_REQUEST, "Invalid Authorization header".to_string()))?;
// Bearer token?
let token = auth_header
.strip_prefix("Bearer ")
.ok_or((StatusCode::BAD_REQUEST, "Expected Bearer token".to_string()))?;
// 2⃣ Decode the token
let token_data = decode::<Claims>( let token_data = decode::<Claims>(
token, token,
&keys.decoding, &keys.decoding,
&Validation::new(Algorithm::HS256), &Validation::new(Algorithm::HS256),
).map_err(|_| (StatusCode::UNAUTHORIZED, "Invalid token".to_string()))?; ).map_err(|_| (StatusCode::UNAUTHORIZED, "Invalid token".into()))?;
Ok(AuthClaims { Ok(AuthClaims {
user_id: token_data.claims.id, user_id: token_data.claims.id,
hotel_id: token_data.claims.hotel_id, hotel_id: token_data.claims.hotel_id,
//username: token_data.claims.username,
}) })
}
impl<S> FromRequestParts<S> for AuthClaims
where
S: Send + Sync + 'static,
AppState: FromRef<S>,
{
type Rejection = (StatusCode, String);
async fn from_request_parts(
parts: &mut Parts,
state: &S,
) -> Result<Self, Self::Rejection> {
let Extension(keys): Extension<JwtKeys> =
Extension::from_request_parts(parts, state)
.await
.map_err(|_| (StatusCode::UNAUTHORIZED, "Missing keys".into()))?;
let auth_header = parts
.headers
.get(axum::http::header::AUTHORIZATION)
.ok_or((StatusCode::UNAUTHORIZED, "Missing Authorization header".into()))?
.to_str()
.map_err(|_| (StatusCode::BAD_REQUEST, "Invalid Authorization header".into()))?;
let token = auth_header
.strip_prefix("Bearer ")
.ok_or((StatusCode::BAD_REQUEST, "Expected Bearer token".into()))?;
auth_claims_from_token(token, &keys)
} }
} }

View File

@@ -16,7 +16,7 @@ pub fn utils_routes() -> Router<AppState> {
Router::new() Router::new()
.route("/login", put(clean_auth_loging)) .route("/login", put(clean_auth_loging))
.route("/register", put(register_user)) .route("/register", put(register_user))
.route("/ws/", get(ws_handler))
.route("/tokentest", put(token_tester)) .route("/tokentest", put(token_tester))
.route("/force_update_password", put(force_update_password)) .route("/force_update_password", put(force_update_password))
.route("/update_password", put(update_password)) .route("/update_password", put(update_password))
@@ -27,7 +27,7 @@ pub fn utils_routes() -> Router<AppState> {
.route("/logout_single_device", post(logout_from_single_device)) .route("/logout_single_device", post(logout_from_single_device))
.route("/logout_all_devices", post(logout_from_all_devices)) .route("/logout_all_devices", post(logout_from_all_devices))
.route("/ws/{req_token}", get(ws_handler))
//.with_state(state) //.with_state(state)
} }

View File

@@ -1,13 +1,14 @@
use dashmap::DashMap; use dashmap::DashMap;
use reqwest::StatusCode;
use std::sync::Arc; use std::sync::Arc;
use axum::extract::{ws::{Message, WebSocket, WebSocketUpgrade}, State}; use axum::{Extension, extract::{State, ws::{Message, WebSocket, WebSocketUpgrade}}};
use tokio::sync::mpsc; use tokio::sync::mpsc;
use axum::extract::Path; use axum::extract::Path;
use axum::response::IntoResponse; use axum::response::IntoResponse;
//use futures_util::stream::stream::StreamExt; //use futures_util::stream::stream::StreamExt;
use futures_util::{StreamExt, SinkExt}; use futures_util::{StreamExt, SinkExt};
use crate::utils::{auth::AuthClaims, db_pool::{AppState, HotelPool}}; use crate::utils::{auth::{AuthClaims, JwtKeys, auth_claims_from_token}, db_pool::{AppState, HotelPool}};
@@ -89,12 +90,37 @@ async fn handle_socket(
} }
pub async fn ws_handler( pub async fn ws_handler(
AuthClaims {user_id, hotel_id}: AuthClaims, //AuthClaims {user_id, hotel_id}: AuthClaims,
ws: WebSocketUpgrade, ws: WebSocketUpgrade,
Extension(keys): Extension<JwtKeys>,
State(state): State<AppState>, State(state): State<AppState>,
//Path((hotel_id, user_id)): Path<(i32, i32)>, Path((req_token)): Path<(String)>,
) -> impl IntoResponse { ) -> impl IntoResponse {
ws.on_upgrade(move |socket| handle_socket(socket, state, hotel_id, user_id))
let token = req_token;
let claims = match auth_claims_from_token(&token, &keys) {
Err(_) => {
print!("error during auth claims processing");
return StatusCode::UNAUTHORIZED.into_response();
}
Ok(c) => c
};
print!("{token}, web socket tried to connect", );
/*
let claims = match auth_claims_from_token(&token, &keys) {
Ok(c) => c,
Err(_) => return StatusCode::UNAUTHORIZED.into_response(),
};
*/
ws.on_upgrade(move |socket| handle_socket(socket, state, claims.hotel_id, claims.user_id))
} }
fn print_ws_state(state: &AppState) { fn print_ws_state(state: &AppState) {