refactoring token -> claim processing

This commit is contained in:
2026-01-03 17:15:34 +01:00
parent c0d70077d7
commit 170fedbcbd
12 changed files with 95 additions and 60 deletions

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -322,7 +322,7 @@ pub async fn send_message(
if let Some(hotel_users) = state.ws_map.get(&hotel_id) {
let update_msg = serde_json::json!({
"event-type": "chat-message",
"event_type": "chat_message",
"conv_id": payload.conv_id,
"sender": user_id,
"content": payload.message,

View File

@@ -77,7 +77,7 @@ pub async fn update_inventory_item(
(StatusCode::OK, format!("updated item history"))
}
Ok(_) => (StatusCode::NOT_FOUND, "No room found".to_string()),
Ok(_) => (StatusCode::NOT_FOUND, "No item found, err : {_}".to_string()),
Err(err) => (StatusCode::INTERNAL_SERVER_ERROR, format!("Error from DB: {err}")),
}
@@ -104,7 +104,7 @@ pub async fn get_inventory_item(
Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, "Pool error".to_string()),
};
let mut stmt = match conn.prepare("SELECT id, amount, item_name, user_id FROM inventory") {
let mut stmt = match conn.prepare("SELECT id, amount, item_name, user_id, updated_at FROM inventory") {
Ok(s) => s,
Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, "Statement error".to_string()),
};
@@ -120,7 +120,7 @@ pub async fn get_inventory_item(
let item = InventoryItems {
id: row.get("id").unwrap_or_default(),
amount: row.get("amount").unwrap_or_default(),
name: row.get("name").unwrap_or_default(),
name: row.get("item_name").unwrap_or_default(),
user_id: row.get("user_id").unwrap_or_default(),
updated_at: row.get("updated_at").unwrap_or_default(),
};

View File

@@ -83,7 +83,7 @@ pub async fn clean_db_update(
}
if let Some(hotel_users) = state.ws_map.get(&hotel_id) {
let update_msg = json!({
"event-type": "room-update",
"event_type": "room_update",
"room_id": room_id,
"status": payload.status,
"updated_by": user_id,

View File

@@ -47,6 +47,8 @@ pub async fn token_tester(
)
}
pub struct AuthUser(pub Claims); //??
#[derive(Debug, Clone)]
@@ -56,44 +58,51 @@ pub struct AuthClaims {
//pub username: String,
}
impl<S> FromRequestParts<S> for AuthClaims
where
S: Send + Sync + 'static,
AppState: Clone + Send + Sync + 'static, AppState: FromRef<S>
{
type Rejection = (StatusCode, String);
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
// We assume your state has a `jwt_secret` field
let Extension(keys): Extension<JwtKeys> =
Extension::from_request_parts(parts, state).await.map_err(|_| (StatusCode::UNAUTHORIZED, "Missing keys".to_string()))?;
// 1⃣ Extract the token from the Authorization header
let auth_header = parts
.headers
.get("Authorization")
.ok_or((StatusCode::UNAUTHORIZED, "Missing Authorization header".to_string()))?
.to_str()
.map_err(|_| (StatusCode::BAD_REQUEST, "Invalid Authorization header".to_string()))?;
// Bearer token?
let token = auth_header
.strip_prefix("Bearer ")
.ok_or((StatusCode::BAD_REQUEST, "Expected Bearer token".to_string()))?;
// 2⃣ Decode the token
pub fn auth_claims_from_token(
token: &str,
keys: &JwtKeys,
) -> Result<AuthClaims, (StatusCode, String)> {
let token_data = decode::<Claims>(
token,
&keys.decoding,
&Validation::new(Algorithm::HS256),
).map_err(|_| (StatusCode::UNAUTHORIZED, "Invalid token".to_string()))?;
).map_err(|_| (StatusCode::UNAUTHORIZED, "Invalid token".into()))?;
Ok(AuthClaims {
user_id: token_data.claims.id,
hotel_id: token_data.claims.hotel_id,
//username: token_data.claims.username,
})
}
impl<S> FromRequestParts<S> for AuthClaims
where
S: Send + Sync + 'static,
AppState: FromRef<S>,
{
type Rejection = (StatusCode, String);
async fn from_request_parts(
parts: &mut Parts,
state: &S,
) -> Result<Self, Self::Rejection> {
let Extension(keys): Extension<JwtKeys> =
Extension::from_request_parts(parts, state)
.await
.map_err(|_| (StatusCode::UNAUTHORIZED, "Missing keys".into()))?;
let auth_header = parts
.headers
.get(axum::http::header::AUTHORIZATION)
.ok_or((StatusCode::UNAUTHORIZED, "Missing Authorization header".into()))?
.to_str()
.map_err(|_| (StatusCode::BAD_REQUEST, "Invalid Authorization header".into()))?;
let token = auth_header
.strip_prefix("Bearer ")
.ok_or((StatusCode::BAD_REQUEST, "Expected Bearer token".into()))?;
auth_claims_from_token(token, &keys)
}
}

View File

@@ -16,7 +16,7 @@ pub fn utils_routes() -> Router<AppState> {
Router::new()
.route("/login", put(clean_auth_loging))
.route("/register", put(register_user))
.route("/ws/", get(ws_handler))
.route("/tokentest", put(token_tester))
.route("/force_update_password", put(force_update_password))
.route("/update_password", put(update_password))
@@ -27,7 +27,7 @@ pub fn utils_routes() -> Router<AppState> {
.route("/logout_single_device", post(logout_from_single_device))
.route("/logout_all_devices", post(logout_from_all_devices))
.route("/ws/{req_token}", get(ws_handler))
//.with_state(state)
}

View File

@@ -1,13 +1,14 @@
use dashmap::DashMap;
use reqwest::StatusCode;
use std::sync::Arc;
use axum::extract::{ws::{Message, WebSocket, WebSocketUpgrade}, State};
use axum::{Extension, extract::{State, ws::{Message, WebSocket, WebSocketUpgrade}}};
use tokio::sync::mpsc;
use axum::extract::Path;
use axum::response::IntoResponse;
//use futures_util::stream::stream::StreamExt;
use futures_util::{StreamExt, SinkExt};
use crate::utils::{auth::AuthClaims, db_pool::{AppState, HotelPool}};
use crate::utils::{auth::{AuthClaims, JwtKeys, auth_claims_from_token}, db_pool::{AppState, HotelPool}};
@@ -89,12 +90,37 @@ async fn handle_socket(
}
pub async fn ws_handler(
AuthClaims {user_id, hotel_id}: AuthClaims,
//AuthClaims {user_id, hotel_id}: AuthClaims,
ws: WebSocketUpgrade,
Extension(keys): Extension<JwtKeys>,
State(state): State<AppState>,
//Path((hotel_id, user_id)): Path<(i32, i32)>,
Path((req_token)): Path<(String)>,
) -> impl IntoResponse {
ws.on_upgrade(move |socket| handle_socket(socket, state, hotel_id, user_id))
let token = req_token;
let claims = match auth_claims_from_token(&token, &keys) {
Err(_) => {
print!("error during auth claims processing");
return StatusCode::UNAUTHORIZED.into_response();
}
Ok(c) => c
};
print!("{token}, web socket tried to connect", );
/*
let claims = match auth_claims_from_token(&token, &keys) {
Ok(c) => c,
Err(_) => return StatusCode::UNAUTHORIZED.into_response(),
};
*/
ws.on_upgrade(move |socket| handle_socket(socket, state, claims.hotel_id, claims.user_id))
}
fn print_ws_state(state: &AppState) {