Files
hotel_api/src/utils/auth.rs

548 lines
17 KiB
Rust
Raw Blame History

This file contains invisible Unicode characters
This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
use std::time::Duration;
use axum::{
body::{to_bytes, Body}, extract::{Extension, FromRequest, FromRequestParts, Path, State}, http::{header::{HeaderValue, SET_COOKIE}, request::Parts, Request as HttpRequest, StatusCode }, middleware::Next, response::{IntoResponse, IntoResponseParts, Response}, Json
};
use axum_extra::extract::TypedHeader;
//use axum_extra::TypedHeader;
use headers::UserAgent;
use axum::extract::FromRef;
use axum::extract::Request as ExtractRequest;
use jsonwebtoken::{decode, DecodingKey, Validation, encode, EncodingKey, Header, Algorithm};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use chrono::{Utc};
use rusqlite::{params, Connection, OptionalExtension};
use rand_core::{RngCore, OsRng};
use argon2::{
password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString},
Argon2};
use uuid::Uuid;
//use crate::utils::db_pool::;
use crate::utils::db_pool::{HotelPool,AppState};
use base64::{engine::general_purpose, Engine as _};
/// Pair of keys used for the JWTs handled in this module.
///
/// Handlers and extractors in this file receive it through `Extension<JwtKeys>`.
#[derive(Clone)]
pub struct JwtKeys {
    /// Key handed to `jsonwebtoken::encode` when a token is issued.
    pub encoding: EncodingKey,
    /// Key handed to `jsonwebtoken::decode` when a token is checked.
    pub decoding: DecodingKey,
}
/// Diagnostic handler that echoes the identity extracted from the caller's token.
///
/// The body only runs when the `AuthClaims` extractor succeeds. The application
/// state is extracted but not used.
///
/// Returns the text `(user_id: <id>) from hotel <hotel_id>`.
pub async fn token_tester(
    State(_state): State<AppState>,
    claims: AuthClaims,
) -> impl IntoResponse {
    let AuthClaims { user_id, hotel_id } = claims;
    format!(
        "(user_id: {}) from hotel {}",
        user_id, hotel_id
    )
}
/// Newtype wrapper around the JWT `Claims` payload.
///
/// NOTE(review): nothing in the visible part of this file constructs or reads
/// this type — confirm whether it is still needed.
pub struct AuthUser(pub Claims);
/// Identity of an authenticated caller, produced by the `FromRequestParts`
/// implementation for this type from a decoded JWT.
#[derive(Clone, Debug)]
pub struct AuthClaims {
    /// Value of the token's `id` claim.
    pub user_id: i32,
    /// Value of the token's `hotel_id` claim.
    pub hotel_id: i32,
}
/// Extractor that authenticates a request from its `Authorization: Bearer <jwt>` header.
///
/// The decoding key is read from the `Extension<JwtKeys>` request extension.
///
/// Rejections (`(StatusCode, String)`):
/// * `401 "Missing keys"` – the `JwtKeys` extension could not be extracted,
/// * `401 "Missing Authorization header"` – the header is absent,
/// * `400 "Invalid Authorization header"` – `HeaderValue::to_str` failed on the value,
/// * `400 "Expected Bearer token"` – the value does not start with `Bearer `,
/// * `401 "Invalid token"` – `jsonwebtoken::decode` rejected the token.
impl<S> FromRequestParts<S> for AuthClaims
where
    S: Send + Sync + 'static,
    AppState: Clone + Send + Sync + 'static, AppState: FromRef<S>
{
    type Rejection = (StatusCode, String);

    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
        // The JWT keys travel as a request extension rather than inside the state.
        let Ok(Extension(keys)) = Extension::<JwtKeys>::from_request_parts(parts, state).await
        else {
            return Err((StatusCode::UNAUTHORIZED, "Missing keys".to_string()));
        };

        // Pull the raw header, then peel it down to the bare token.
        let Some(raw_header) = parts.headers.get("Authorization") else {
            return Err((
                StatusCode::UNAUTHORIZED,
                "Missing Authorization header".to_string(),
            ));
        };
        let Ok(header_text) = raw_header.to_str() else {
            return Err((
                StatusCode::BAD_REQUEST,
                "Invalid Authorization header".to_string(),
            ));
        };
        let Some(jwt) = header_text.strip_prefix("Bearer ") else {
            return Err((StatusCode::BAD_REQUEST, "Expected Bearer token".to_string()));
        };

        // Decode the token as HS256 and keep only the ids the handlers need.
        match decode::<Claims>(jwt, &keys.decoding, &Validation::new(Algorithm::HS256)) {
            Ok(verified) => Ok(AuthClaims {
                user_id: verified.claims.id,
                hotel_id: verified.claims.hotel_id,
            }),
            Err(_) => Err((StatusCode::UNAUTHORIZED, "Invalid token".to_string())),
        }
    }
}
/// Hashes `password` with `Argon2::default()` and a freshly generated salt.
///
/// Returns the hash in its string form, which is what `PasswordHash::new`
/// parses again when the password is verified.
///
/// # Errors
/// Any error reported by the hasher, converted into an `anyhow::Error`.
fn hash_password(password: &str) -> anyhow::Result<String> {
    // A new salt is drawn from the OS RNG on every call.
    let salt = SaltString::generate(&mut OsRng);
    match Argon2::default().hash_password(password.as_bytes(), &salt) {
        Ok(hash) => Ok(hash.to_string()),
        Err(e) => Err(anyhow::anyhow!(e)),
    }
}
/// Reports whether `password` matches the Argon2 hash string `stored_hash`.
///
/// A `stored_hash` that cannot be parsed counts as a mismatch (`false`); this
/// function never returns an error.
fn verify_password(password: &str, stored_hash: &str) -> bool {
    let Ok(parsed) = PasswordHash::new(stored_hash) else {
        return false;
    };
    let outcome = Argon2::default().verify_password(password.as_bytes(), &parsed);
    outcome.is_ok()
}
/// JSON body of a registration request.
#[derive(Debug, Deserialize)]
pub struct RegisterValues {
    /// Login name, stored in `users.username`.
    username: String,
    /// Plain-text password; `register_user` hashes it before storing.
    password: String,
    /// Hotel the account belongs to, stored in `users.hotel_id`.
    hotel_id: i32,
    /// Name stored in `users.displayname`.
    displayname: String,
}
/// Extractor that parses the request body as JSON into [`RegisterValues`].
///
/// A body that cannot be parsed is rejected with `400 Bad Request` and the
/// message `Invalid body: <error>`.
pub struct RegisterPayload(pub RegisterValues);

impl<S> FromRequest<S> for RegisterPayload
where
    S: Send + Sync,
{
    type Rejection = (StatusCode, String);

    async fn from_request(req: ExtractRequest, state: &S) -> Result<Self, Self::Rejection> {
        // Delegate the parsing to `Json`, then re-wrap either outcome.
        match Json::<RegisterValues>::from_request(req, state).await {
            Ok(Json(values)) => Ok(RegisterPayload(values)),
            Err(rejection) => Err((
                StatusCode::BAD_REQUEST,
                format!("Invalid body: {}", rejection),
            )),
        }
    }
}
pub async fn register_user (
State(state): State<AppState>,
RegisterPayload(payload): RegisterPayload
) -> Result<impl IntoResponse, (StatusCode, &'static str)> {
let hashed_password = hash_password(&payload.password)
.map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "Password hashing failed"))?;
let conn = state.logs_pool.get()
.map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "DB connection error"))?;
conn.execute(
"INSERT INTO users (username, password, hotel_id, displayname) VALUES (?1, ?2, ?3, ?4)",
params![payload.username, hashed_password, payload.hotel_id, payload.displayname],
)
.map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "DB insert error"))?;
Ok((StatusCode::CREATED, "User registered successfully"))
}
/// JSON body of an administrative password reset.
#[derive(Debug, Deserialize, Serialize)]
pub struct ForceUpdatePasswordValues {
    /// Login name of the account to reset.
    username: String,
    /// New plain-text password; hashed before it is stored.
    newpassword: String,
    /// Hotel the account belongs to; used together with `username` in the lookup.
    hotel_id: i32,
    /// Administrator secret authorising the reset.
    admin_pass: String,
}
//pub struct ForceUpdatePasswordPayload (pub ForceUpdatePasswordValues);
pub async fn ForceUpdatePassword(
State(state): State<AppState>,
Json(payload): Json<ForceUpdatePasswordValues>,
) -> impl IntoResponse {
let conn = match state.logs_pool.get() {
Ok(c) => c,
Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, "DB, conn failed").into_response()
};
let user_row = match conn.query_row(
"SELECT id FROM users WHERE username = ?1 AND hotel_id = ?2",
params![&payload.username, &payload.hotel_id],
|row|{
let user_id: i32 = row.get(0)?;
//let hotel_id: i32 = row.get(1)?;
Ok((user_id))
},
).optional() {
Ok(opt) => opt,
Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, "DB query error")
.into_response(),
};
let (user_id) = match user_row {
Some(u) => u,
None => return (StatusCode::UNAUTHORIZED, "Not correct user")
.into_response(),
};
let admin_check: String = "my_admin_password".to_string();
if &payload.admin_pass != &admin_check {
return (StatusCode::UNAUTHORIZED, "Invalid Amin Password").into_response()
};
let hashed_password = match hash_password(&payload.newpassword) {
Ok(h) => h,
Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, "Password hashing failed").into_response(),
};
let result = conn.execute(
"UPDATE users SET password = ?1 WHERE id = ?2",
params![&hashed_password, &user_id],
);
match result {
Ok(rows) if rows > 0 => (StatusCode::OK, "Password updated").into_response(),
Ok(_) => (StatusCode::NOT_FOUND, "User not found").into_response(),
Err(_) => (StatusCode::INTERNAL_SERVER_ERROR, "Failed to update password").into_response(),
}
}
/// JSON body of a self-service password change.
#[derive(Debug, Deserialize, Serialize)]
pub struct UpdatePasswordValues {
    /// Login name of the account.
    username: String,
    /// The account's present plain-text password.
    current_password: String,
    /// Replacement plain-text password; hashed before it is stored.
    newpassword: String,
    /// Hotel the account belongs to.
    hotel_id: i32,
}
pub async fn UpdatePassword(
State(state): State<AppState>,
Json(payload): Json<UpdatePasswordValues>,
) -> impl IntoResponse {
let conn = match state.logs_pool.get() {
Ok(c) => c,
Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, "DB, conn failed").into_response()
};
let user_row = match conn.query_row(
"SELECT password, id FROM users WHERE username = ?1 AND current_password = ?2",
params![&payload.username, &payload.current_password],
|row|{
let password: String = row.get(0)?;
let id: i32 = row.get(1)?;
Ok((password, id))
},
).optional() {
Ok(opt) => opt,
Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, "DB query error")
.into_response(),
};
let (password, user_id) = match user_row {
Some(u) => u,
None => return (StatusCode::UNAUTHORIZED, "Not correct user")
.into_response(),
};
if verify_password( &payload.current_password, &password ) {
return (StatusCode::UNAUTHORIZED, "Invalid Password").into_response()
};
let hashed_password = match hash_password(&payload.newpassword) {
Ok(h) => h,
Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, "Password hashing failed").into_response(),
};
let result = conn.execute(
"UPDATE users SET password = ?1 WHERE id = ?2",
params![&hashed_password, &user_id],
);
match result {
Ok(rows) if rows > 0 => (StatusCode::OK, "Password updated").into_response(),
Ok(_) => (StatusCode::NOT_FOUND, "User not found").into_response(),
Err(_) => (StatusCode::INTERNAL_SERVER_ERROR, "Failed to update password").into_response(),
}
}
/// JSON body of a username/password login request.
#[derive(Debug, Deserialize)]
pub struct LoginValues {
    /// Login name looked up in `users.username`.
    username: String,
    /// Plain-text password to verify against the stored hash.
    password: String,
    /// Hotel id sent by the client.
    hotel_id: i32,
}
/// Extractor that parses the request body as JSON into [`LoginValues`].
///
/// A body that cannot be parsed is rejected with `400 Bad Request` and the
/// message `Invalid body: <error>`.
pub struct LoginPayload(pub LoginValues);

impl<S> FromRequest<S> for LoginPayload
where
    S: Send + Sync,
{
    type Rejection = (StatusCode, String);

    async fn from_request(req: ExtractRequest, state: &S) -> Result<Self, Self::Rejection> {
        // Delegate the parsing to `Json`, then re-wrap either outcome.
        match Json::<LoginValues>::from_request(req, state).await {
            Ok(Json(values)) => Ok(LoginPayload(values)),
            Err(rejection) => Err((
                StatusCode::BAD_REQUEST,
                format!("Invalid body: {}", rejection),
            )),
        }
    }
}
/// Payload of the JWTs decoded by the `AuthClaims` extractor.
///
/// Every field is required when a token is deserialised into this type.
#[derive(Clone, Debug, Deserialize, Serialize)]
struct Claims {
    /// User id (`users.id`).
    id: i32,
    /// Hotel the user belongs to.
    hotel_id: i32,
    /// Login name of the user.
    username: String,
    /// Expiry claim; the login handlers set it to a Unix timestamp in seconds.
    exp: usize,
}
/// JSON body returned by the login handlers on success.
#[derive(Serialize)]
struct LoginResponse {
    /// The encoded JWT.
    token: String,
}
pub async fn clean_auth_loging(
State(state): State<AppState>,
Extension(keys): Extension<JwtKeys>,
LoginPayload(payload): LoginPayload,
) -> impl IntoResponse {
// 1⃣ Get a connection from logs pool
let conn = match state.logs_pool.get() {
Ok(c) => c,
Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, "DB connection error").into_response(),
};
let user_row = match conn.query_row(
"SELECT id, password, hotel_id, displayname FROM users WHERE username = ?1",
params![&payload.username],
|row| {
let user_id: i32 = row.get(0)?;
let password: String = row.get(1)?;
let hotel_id: i32 = row.get(2)?;
let displayname: String = row.get(3)?;
Ok((user_id, password, hotel_id, displayname))
},
).optional() {
Ok(opt) => opt,
Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, "DB query error").into_response(),
};
let (user_id, stored_hash, hotel_id, displayname) = match user_row {
Some(u) => u,
None => return (StatusCode::UNAUTHORIZED, "Invalid credentials").into_response(),
};
if !verify_password(&payload.password, &stored_hash) {
return (StatusCode::UNAUTHORIZED, "Invelid credentials").into_response();
}
let expiration = match chrono::Utc::now().checked_add_signed(chrono::Duration::hours(15)) {
Some(time) => time.timestamp() as usize,
None => {
// Handle overflow — probably a 500, since this should never happen
return (StatusCode::INTERNAL_SERVER_ERROR, "Time overflow".to_string()).into_response();
}
};
let claims = serde_json::json!({
"id": user_id,
"hotel_id": hotel_id,
"username": payload.username,
"exp": expiration
});
let token = match encode(
&Header::default(),
&claims,
&keys.encoding
) {
Ok(t) => t,
Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, "JWT creation failed").into_response(),
};
Json(LoginResponse { token }).into_response()
}
/// JSON body of a request to create a refresh token.
#[derive(Debug, Deserialize)]
pub struct CreateRefreshTokenValue {
    /// Login name looked up in `users.username`.
    pub username: String,
    /// Plain-text password to verify against the stored hash.
    pub password: String,
    /// Client-supplied device identifier, stored with the token.
    pub device_id: Uuid,
}
#[axum::debug_handler]
pub async fn create_refresh_token(
State(state): State<AppState>,
user_agent: Option<TypedHeader<UserAgent>>,
Json(payload): Json<CreateRefreshTokenValue>
) -> Result<impl IntoResponse, (StatusCode, String)> { // ← Add Result here
let user_agent_str = user_agent
.map(|ua| ua.to_string())
.unwrap_or_else(|| "Unknown".to_string());
let device_id_str = payload.device_id.to_string();
let argon2 = Argon2::default();
let salt = SaltString::generate(&mut OsRng);
let mut bytes = [0u8; 64];
OsRng.fill_bytes(&mut bytes);
let raw_token = Uuid::new_v4().to_string();
let hashed_token = argon2
.hash_password(raw_token.as_bytes(), &salt)
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?
.to_string();
let conn = state.logs_pool.get()
.map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "DB connection error".to_string()))?;
let user_row = conn.query_row(
"SELECT id, password, hotel_id FROM users WHERE username = ?1",
params![&payload.username],
|row| {
let user_id: i32 = row.get(0)?;
let password: String = row.get(1)?;
let hotel_id: i32 = row.get(2)?;
Ok((user_id, password, hotel_id))
},
).optional()
.map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "DB get user id error".to_string()))?;
let (user_id, stored_hash, hotel_id) = user_row
.ok_or((StatusCode::NOT_FOUND, "User not found".to_string()))?;
if !verify_password(&payload.password, &stored_hash) {
return Err((StatusCode::UNAUTHORIZED, "Invalid credentials".to_string()));
}
conn.execute(
"INSERT INTO refresh_token (user_id, token_hash, device_id, user_agent, hotel_id) VALUES (?1, ?2, ?3, ?4, ?5)",
params![user_id, hashed_token, device_id_str, user_agent_str, hotel_id],
)
.map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "DB insert error".to_string()))?;
let cookie_value = format!("refresh_token={}; HttpOnly; Secure; Path=/", raw_token);
let mut response = (StatusCode::CREATED, "Refresh token created successfully").into_response();
response.headers_mut().insert(
SET_COOKIE,
HeaderValue::from_str(&cookie_value).unwrap(),
);
Ok(response) // ← Wrap in Ok()
}
/// JSON body of a login that uses a previously issued refresh token.
#[derive(Deserialize)]
pub struct LoginRefreshTokenValues {
    /// Device identifier the refresh token was stored under.
    device_id: Uuid,
    /// Raw refresh token, verified against the stored hash.
    refresh_token: String,
}
pub async fn login_refresh_token (
State(state): State<AppState>,
Extension(keys): Extension<JwtKeys>,
user_agent: Option<TypedHeader<UserAgent>>,
Json(payload): Json<LoginRefreshTokenValues>
) -> impl IntoResponse {
let conn = match state.logs_pool.get() {
Ok(c) => c,
Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, "DB connection error").into_response(),
};
let user_agent_str = user_agent
.map(|ua| ua.to_string())
.unwrap_or_else(|| "Unknown".to_string());
let device_id_str = payload.device_id.to_string();
//"SELECT user_id, token_hash, hotel_id FROM refresh_token WHERE device_id = ?1 AND user_agent = ?2",
let device_row = match conn.query_row(
"SELECT user_id, token_hash, hotel_id FROM refresh_token WHERE device_id = ?1 AND user_agent = ?2",
params![&device_id_str, &user_agent_str],
|row| {
let user_id: i32 = row.get(0)?;
let token_hash: String = row.get(1)?;
let hotel_id: i32 = row.get(2)?;
//let displayname: String = row.get(3)?;
Ok((user_id, token_hash, hotel_id))
},
).optional() {
Ok(opt) => opt,
Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, "DB query error").into_response(),
};
let (user_id, token_hash, hotel_id) = match device_row {
Some(tuple) => tuple,
None => return (StatusCode::UNAUTHORIZED, "No matching device").into_response(),
};
if !verify_password(&payload.refresh_token, &token_hash) {
return (StatusCode::UNAUTHORIZED, "Invalid or mismatched token").into_response();
}
let expiration = match chrono::Utc::now().checked_add_signed(chrono::Duration::hours(15)) {
Some(time) => time.timestamp() as usize,
None => {
// Handle overflow — probably a 500, since this should never happen
return (StatusCode::INTERNAL_SERVER_ERROR, "Time overflow".to_string()).into_response();
}
};
let claims = serde_json::json!({
"id": user_id,
"hotel_id": hotel_id,
//"username": payload.username,
"exp": expiration
});
let token = match encode(
&Header::default(),
&claims,
&keys.encoding
) {
Ok(t) => t,
Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, "JWT creation failed").into_response(),
};
Json(LoginResponse { token }).into_response()
}
/// Turns any displayable error into a `500 Internal Server Error` pair whose
/// message is `Internal error: <err>`.
fn internal_error<E: std::fmt::Display>(err: E) -> (StatusCode, String) {
    let message = format!("Internal error: {}", err);
    (StatusCode::INTERNAL_SERVER_ERROR, message)
}