diff --git a/src/main.rs b/src/main.rs index 37eb29e..60b150a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,4 +1,6 @@ use axum::serve; +use axum::Extension; +use jsonwebtoken::{DecodingKey, EncodingKey}; use tokio::net::TcpListener; mod utils; @@ -8,7 +10,7 @@ use r2d2::{Pool}; use r2d2_sqlite::SqliteConnectionManager; use crate::utils::db_pool::{HotelPool,AppState}; use routes::create_router; - +use crate::utils::auth::JwtKeys; @@ -25,11 +27,18 @@ async fn main() -> std::io::Result<()> { let state = AppState { hotel_pools, logs_pool, - jwt_secret: "your_jwt_secret_key s".to_string(), // better: load from env var + //jwt_secret: "your_jwt_secret_key s".to_string(), // better: load from env var }; + let jwt_secret = "your_jwt_secret_key".to_string(); + let jwt_keys = JwtKeys { + encoding: EncodingKey::from_secret(jwt_secret.as_ref()), + decoding: DecodingKey::from_secret(jwt_secret.as_ref()), + }; + + let app = create_router(state) + .layer(Extension(jwt_keys)); - let app = create_router(state); let listener = TcpListener::bind("0.0.0.0:3000").await?; serve(listener, app).into_future().await?; Ok(()) diff --git a/src/utils/auth.rs b/src/utils/auth.rs index dce120f..160dfc5 100644 --- a/src/utils/auth.rs +++ b/src/utils/auth.rs @@ -6,7 +6,7 @@ use axum::{ middleware::Next, response::{Response, IntoResponse}, Json, - extract::{Path, State, FromRequest, FromRequestParts} + extract::{Path, State, FromRequest, FromRequestParts, Extension} }; use axum::extract::FromRef; use axum::extract::Request as ExtractRequest; @@ -15,7 +15,6 @@ use serde::{Deserialize, Serialize}; use serde_json::Value; use chrono::{Utc}; use rusqlite::{params, Connection, OptionalExtension}; -use async_trait::async_trait; use rand_core::OsRng; use argon2::{ @@ -25,8 +24,15 @@ use argon2::{ //use crate::utils::db_pool::; use crate::utils::db_pool::{HotelPool,AppState}; +#[derive(Clone)] +pub struct JwtKeys { + pub encoding: EncodingKey, + pub decoding: DecodingKey, +} + pub async fn token_tester( State(state): State, + Extension(keys): Extension, AuthClaims { user_id, hotel_id, username }: AuthClaims, ) -> impl IntoResponse { format!( @@ -53,7 +59,9 @@ where async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { // We assume your state has a `jwt_secret` field - + let Extension(keys): Extension = + Extension::from_request_parts(parts, state).await.map_err(|_| (StatusCode::UNAUTHORIZED, "Missing keys".to_string()))?; + // 1️⃣ Extract the token from the Authorization header let auth_header = parts .headers @@ -70,7 +78,7 @@ where // 2️⃣ Decode the token let token_data = decode::( token, - &DecodingKey::from_secret("your_jwt_secret_key s".to_string().as_ref()), + &keys.decoding, &Validation::new(Algorithm::HS256), ).map_err(|_| (StatusCode::UNAUTHORIZED, "Invalid token".to_string()))?; @@ -186,14 +194,15 @@ struct LoginResponse { pub async fn clean_auth_loging( State(state): State, - LoginPayload(payload): LoginPayload + Extension(keys): Extension, + LoginPayload(payload): LoginPayload, ) -> impl IntoResponse { // 1️⃣ Get a connection from logs pool let conn = match state.logs_pool.get() { Ok(c) => c, Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, "DB connection error").into_response(), }; - + let user_row = match conn.query_row( "SELECT id, password, hotel_id, displayname FROM users WHERE username = ?1", params![&payload.username], @@ -234,7 +243,7 @@ pub async fn clean_auth_loging( let token = match encode( &Header::default(), &claims, - &EncodingKey::from_secret(state.jwt_secret.as_ref()), + &keys.encoding ) { Ok(t) => t, Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, "JWT creation failed").into_response(), diff --git a/src/utils/db_pool.rs b/src/utils/db_pool.rs index db12556..7d1d5f8 100644 --- a/src/utils/db_pool.rs +++ b/src/utils/db_pool.rs @@ -9,7 +9,7 @@ type HotelId = i32; // or i32 if you want numeric ids pub struct AppState { pub hotel_pools: HotelPool, pub logs_pool: Pool, - pub jwt_secret: String + }