use std::time::Duration; use axum::{ body::{to_bytes, Body}, extract::{Extension, FromRequest, FromRequestParts, Path, State}, http::{header::{HeaderValue, SET_COOKIE}, request::Parts, Request as HttpRequest, StatusCode }, middleware::Next, response::{IntoResponse, IntoResponseParts, Response}, Json }; use axum_extra::extract::TypedHeader; //use axum_extra::TypedHeader; use headers::UserAgent; use axum::extract::FromRef; use axum::extract::Request as ExtractRequest; use jsonwebtoken::{decode, DecodingKey, Validation, encode, EncodingKey, Header, Algorithm}; use serde::{Deserialize, Serialize}; use serde_json::Value; use chrono::{Utc}; use rusqlite::{params, Connection, OptionalExtension}; use rand_core::{RngCore, OsRng}; use argon2::{ password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString}, Argon2}; use uuid::Uuid; //use crate::utils::db_pool::; use crate::utils::db_pool::{HotelPool,AppState}; use base64::{engine::general_purpose, Engine as _}; #[derive(Clone)] pub struct JwtKeys { pub encoding: EncodingKey, pub decoding: DecodingKey, } pub async fn token_tester( State(state): State, //Extension(keys): Extension, AuthClaims { user_id, hotel_id }: AuthClaims, ) -> impl IntoResponse { format!( "(user_id: {}) from hotel {}", user_id, hotel_id ) } pub struct AuthUser(pub Claims); //?? #[derive(Debug, Clone)] pub struct AuthClaims { pub user_id: i32, pub hotel_id: i32, //pub username: String, } impl FromRequestParts for AuthClaims where S: Send + Sync + 'static, AppState: Clone + Send + Sync + 'static, AppState: FromRef { type Rejection = (StatusCode, String); async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { // We assume your state has a `jwt_secret` field let Extension(keys): Extension = Extension::from_request_parts(parts, state).await.map_err(|_| (StatusCode::UNAUTHORIZED, "Missing keys".to_string()))?; // 1️⃣ Extract the token from the Authorization header let auth_header = parts .headers .get("Authorization") .ok_or((StatusCode::UNAUTHORIZED, "Missing Authorization header".to_string()))? .to_str() .map_err(|_| (StatusCode::BAD_REQUEST, "Invalid Authorization header".to_string()))?; // Bearer token? let token = auth_header .strip_prefix("Bearer ") .ok_or((StatusCode::BAD_REQUEST, "Expected Bearer token".to_string()))?; // 2️⃣ Decode the token let token_data = decode::( token, &keys.decoding, &Validation::new(Algorithm::HS256), ).map_err(|_| (StatusCode::UNAUTHORIZED, "Invalid token".to_string()))?; Ok(AuthClaims { user_id: token_data.claims.id, hotel_id: token_data.claims.hotel_id, //username: token_data.claims.username, }) } } // Hash a new password fn hash_password(password: &str) -> anyhow::Result { let salt = SaltString::generate(&mut OsRng); // unique per password let argon2 = Argon2::default(); // Argon2id with good defaults let password_hash = argon2 .hash_password(password.as_bytes(), &salt) .map_err(|e| anyhow::anyhow!(e))? .to_string(); Ok(password_hash) } // Verify an incoming password against stored hash fn verify_password(password: &str, stored_hash: &str) -> bool { let parsed_hash = match PasswordHash::new(&stored_hash) { Ok(hash) => hash, Err(_) => return false, }; Argon2::default() .verify_password(password.as_bytes(), &parsed_hash).is_ok() } #[derive(Deserialize, Debug)] pub struct RegisterValues{ username: String, password: String, hotel_id: i32, displayname: String, } pub struct RegisterPayload(pub RegisterValues); impl FromRequest for RegisterPayload where S: Send + Sync, { type Rejection = (StatusCode, String); async fn from_request(req: ExtractRequest, state: &S) -> Result { let Json(payload) = Json::::from_request(req, state) .await .map_err(|err| (StatusCode::BAD_REQUEST, format!("Invalid body: {}", err)))?; Ok(RegisterPayload(payload)) } } pub async fn register_user ( State(state): State, RegisterPayload(payload): RegisterPayload ) -> Result { let hashed_password = hash_password(&payload.password) .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "Password hashing failed"))?; let conn = state.logs_pool.get() .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "DB connection error"))?; conn.execute( "INSERT INTO users (username, password, hotel_id, displayname) VALUES (?1, ?2, ?3, ?4)", params![payload.username, hashed_password, payload.hotel_id, payload.displayname], ) .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "DB insert error"))?; Ok((StatusCode::CREATED, "User registered successfully")) } #[derive(Serialize, Deserialize, Debug)] pub struct ForceUpdatePasswordValues{ username: String, newpassword: String, hotel_id: i32, admin_pass: String, } //pub struct ForceUpdatePasswordPayload (pub ForceUpdatePasswordValues); pub async fn ForceUpdatePassword( State(state): State, Json(payload): Json, ) -> impl IntoResponse { let conn = match state.logs_pool.get() { Ok(c) => c, Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, "DB, conn failed").into_response() }; let user_row = match conn.query_row( "SELECT id FROM users WHERE username = ?1 AND hotel_id = ?2", params![&payload.username, &payload.hotel_id], |row|{ let user_id: i32 = row.get(0)?; //let hotel_id: i32 = row.get(1)?; Ok((user_id)) }, ).optional() { Ok(opt) => opt, Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, "DB query error") .into_response(), }; let (user_id) = match user_row { Some(u) => u, None => return (StatusCode::UNAUTHORIZED, "Not correct user") .into_response(), }; let admin_check: String = "my_admin_password".to_string(); if &payload.admin_pass != &admin_check { return (StatusCode::UNAUTHORIZED, "Invalid Amin Password").into_response() }; let hashed_password = match hash_password(&payload.newpassword) { Ok(h) => h, Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, "Password hashing failed").into_response(), }; let result = conn.execute( "UPDATE users SET password = ?1 WHERE id = ?2", params![&hashed_password, &user_id], ); match result { Ok(rows) if rows > 0 => (StatusCode::OK, "Password updated").into_response(), Ok(_) => (StatusCode::NOT_FOUND, "User not found").into_response(), Err(_) => (StatusCode::INTERNAL_SERVER_ERROR, "Failed to update password").into_response(), } } #[derive(Serialize, Deserialize, Debug)] pub struct UpdatePasswordValues{ username: String, current_password: String, newpassword: String, hotel_id: i32, } pub async fn UpdatePassword( State(state): State, Json(payload): Json, ) -> impl IntoResponse { let conn = match state.logs_pool.get() { Ok(c) => c, Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, "DB, conn failed").into_response() }; let user_row = match conn.query_row( "SELECT password, id FROM users WHERE username = ?1 AND current_password = ?2", params![&payload.username, &payload.current_password], |row|{ let password: String = row.get(0)?; let id: i32 = row.get(1)?; Ok((password, id)) }, ).optional() { Ok(opt) => opt, Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, "DB query error") .into_response(), }; let (password, user_id) = match user_row { Some(u) => u, None => return (StatusCode::UNAUTHORIZED, "Not correct user") .into_response(), }; if verify_password( &payload.current_password, &password ) { return (StatusCode::UNAUTHORIZED, "Invalid Password").into_response() }; let hashed_password = match hash_password(&payload.newpassword) { Ok(h) => h, Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, "Password hashing failed").into_response(), }; let result = conn.execute( "UPDATE users SET password = ?1 WHERE id = ?2", params![&hashed_password, &user_id], ); match result { Ok(rows) if rows > 0 => (StatusCode::OK, "Password updated").into_response(), Ok(_) => (StatusCode::NOT_FOUND, "User not found").into_response(), Err(_) => (StatusCode::INTERNAL_SERVER_ERROR, "Failed to update password").into_response(), } } #[derive(Deserialize, Debug)] pub struct LoginValues { username : String, password : String, hotel_id: i32, } pub struct LoginPayload(pub LoginValues); impl FromRequest for LoginPayload where S: Send + Sync, { type Rejection = (StatusCode, String); async fn from_request(req: ExtractRequest, state: &S) -> Result { let Json(payload) = Json::::from_request(req, state) .await .map_err(|err| (StatusCode::BAD_REQUEST, format!("Invalid body: {}", err)))?; Ok(LoginPayload(payload)) } } #[derive(Deserialize,Debug, Serialize, Clone)] struct Claims{ id: i32, hotel_id: i32, //display_name username: String, exp: usize, } #[derive(Serialize)] struct LoginResponse { token: String, } pub async fn clean_auth_loging( State(state): State, Extension(keys): Extension, LoginPayload(payload): LoginPayload, ) -> impl IntoResponse { // 1️⃣ Get a connection from logs pool let conn = match state.logs_pool.get() { Ok(c) => c, Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, "DB connection error").into_response(), }; let user_row = match conn.query_row( "SELECT id, password, hotel_id, displayname FROM users WHERE username = ?1", params![&payload.username], |row| { let user_id: i32 = row.get(0)?; let password: String = row.get(1)?; let hotel_id: i32 = row.get(2)?; let displayname: String = row.get(3)?; Ok((user_id, password, hotel_id, displayname)) }, ).optional() { Ok(opt) => opt, Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, "DB query error").into_response(), }; let (user_id, stored_hash, hotel_id, displayname) = match user_row { Some(u) => u, None => return (StatusCode::UNAUTHORIZED, "Invalid credentials").into_response(), }; if !verify_password(&payload.password, &stored_hash) { return (StatusCode::UNAUTHORIZED, "Invelid credentials").into_response(); } let expiration = match chrono::Utc::now().checked_add_signed(chrono::Duration::hours(15)) { Some(time) => time.timestamp() as usize, None => { // Handle overflow — probably a 500, since this should never happen return (StatusCode::INTERNAL_SERVER_ERROR, "Time overflow".to_string()).into_response(); } }; let claims = serde_json::json!({ "id": user_id, "hotel_id": hotel_id, "username": payload.username, "exp": expiration }); let token = match encode( &Header::default(), &claims, &keys.encoding ) { Ok(t) => t, Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, "JWT creation failed").into_response(), }; Json(LoginResponse { token }).into_response() } #[derive(Deserialize, Debug)] pub struct CreateRefreshTokenValue { pub username: String, pub password: String, pub device_id: Uuid, //pub timestamp: Option, } #[axum::debug_handler] pub async fn create_refresh_token( State(state): State, user_agent: Option>, Json(payload): Json ) -> Result { // ← Add Result here let user_agent_str = user_agent .map(|ua| ua.to_string()) .unwrap_or_else(|| "Unknown".to_string()); let device_id_str = payload.device_id.to_string(); let argon2 = Argon2::default(); let salt = SaltString::generate(&mut OsRng); let mut bytes = [0u8; 64]; OsRng.fill_bytes(&mut bytes); let raw_token = Uuid::new_v4().to_string(); let hashed_token = argon2 .hash_password(raw_token.as_bytes(), &salt) .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))? .to_string(); let conn = state.logs_pool.get() .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "DB connection error".to_string()))?; let user_row = conn.query_row( "SELECT id, password, hotel_id FROM users WHERE username = ?1", params![&payload.username], |row| { let user_id: i32 = row.get(0)?; let password: String = row.get(1)?; let hotel_id: i32 = row.get(2)?; Ok((user_id, password, hotel_id)) }, ).optional() .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "DB get user id error".to_string()))?; let (user_id, stored_hash, hotel_id) = user_row .ok_or((StatusCode::NOT_FOUND, "User not found".to_string()))?; if !verify_password(&payload.password, &stored_hash) { return Err((StatusCode::UNAUTHORIZED, "Invalid credentials".to_string())); } conn.execute( "INSERT INTO refresh_token (user_id, token_hash, device_id, user_agent, hotel_id) VALUES (?1, ?2, ?3, ?4, ?5)", params![user_id, hashed_token, device_id_str, user_agent_str, hotel_id], ) .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "DB insert error".to_string()))?; let cookie_value = format!("refresh_token={}; HttpOnly; Secure; Path=/", raw_token); let mut response = (StatusCode::CREATED, "Refresh token created successfully").into_response(); response.headers_mut().insert( SET_COOKIE, HeaderValue::from_str(&cookie_value).unwrap(), ); Ok(response) // ← Wrap in Ok() } #[derive(Deserialize)] pub struct LoginRefreshTokenValues{ device_id: Uuid, refresh_token: String, } pub async fn login_refresh_token ( State(state): State, Extension(keys): Extension, user_agent: Option>, Json(payload): Json ) -> impl IntoResponse { let conn = match state.logs_pool.get() { Ok(c) => c, Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, "DB connection error").into_response(), }; let user_agent_str = user_agent .map(|ua| ua.to_string()) .unwrap_or_else(|| "Unknown".to_string()); let device_id_str = payload.device_id.to_string(); //"SELECT user_id, token_hash, hotel_id FROM refresh_token WHERE device_id = ?1 AND user_agent = ?2", let device_row = match conn.query_row( "SELECT user_id, token_hash, hotel_id FROM refresh_token WHERE device_id = ?1 AND user_agent = ?2", params![&device_id_str, &user_agent_str], |row| { let user_id: i32 = row.get(0)?; let token_hash: String = row.get(1)?; let hotel_id: i32 = row.get(2)?; //let displayname: String = row.get(3)?; Ok((user_id, token_hash, hotel_id)) }, ).optional() { Ok(opt) => opt, Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, "DB query error").into_response(), }; let (user_id, token_hash, hotel_id) = match device_row { Some(tuple) => tuple, None => return (StatusCode::UNAUTHORIZED, "No matching device").into_response(), }; if !verify_password(&payload.refresh_token, &token_hash) { return (StatusCode::UNAUTHORIZED, "Invalid or mismatched token").into_response(); } let expiration = match chrono::Utc::now().checked_add_signed(chrono::Duration::hours(15)) { Some(time) => time.timestamp() as usize, None => { // Handle overflow — probably a 500, since this should never happen return (StatusCode::INTERNAL_SERVER_ERROR, "Time overflow".to_string()).into_response(); } }; let claims = serde_json::json!({ "id": user_id, "hotel_id": hotel_id, //"username": payload.username, "exp": expiration }); let token = match encode( &Header::default(), &claims, &keys.encoding ) { Ok(t) => t, Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, "JWT creation failed").into_response(), }; Json(LoginResponse { token }).into_response() } fn internal_error(err: E) -> (StatusCode, String) { (StatusCode::INTERNAL_SERVER_ERROR, format!("Internal error: {}", err)) }