548 lines
17 KiB
Rust
548 lines
17 KiB
Rust
use std::time::Duration;
|
||
use axum::{
|
||
body::{to_bytes, Body}, extract::{Extension, FromRequest, FromRequestParts, Path, State}, http::{header::{HeaderValue, SET_COOKIE}, request::Parts, Request as HttpRequest, StatusCode }, middleware::Next, response::{IntoResponse, IntoResponseParts, Response}, Json
|
||
};
|
||
|
||
use axum_extra::extract::TypedHeader;
|
||
//use axum_extra::TypedHeader;
|
||
|
||
use headers::UserAgent;
|
||
|
||
use axum::extract::FromRef;
|
||
use axum::extract::Request as ExtractRequest;
|
||
use jsonwebtoken::{decode, DecodingKey, Validation, encode, EncodingKey, Header, Algorithm};
|
||
use serde::{Deserialize, Serialize};
|
||
use serde_json::Value;
|
||
use chrono::{Utc};
|
||
use rusqlite::{params, Connection, OptionalExtension};
|
||
|
||
use rand_core::{RngCore, OsRng};
|
||
use argon2::{
|
||
password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString},
|
||
Argon2};
|
||
|
||
use uuid::Uuid;
|
||
|
||
//use crate::utils::db_pool::;
|
||
use crate::utils::db_pool::{HotelPool,AppState};
|
||
|
||
use base64::{engine::general_purpose, Engine as _};
|
||
|
||
|
||
#[derive(Clone)]
|
||
pub struct JwtKeys {
|
||
pub encoding: EncodingKey,
|
||
pub decoding: DecodingKey,
|
||
}
|
||
|
||
pub async fn token_tester(
|
||
State(state): State<AppState>,
|
||
//Extension(keys): Extension<JwtKeys>,
|
||
AuthClaims { user_id, hotel_id }: AuthClaims,
|
||
) -> impl IntoResponse {
|
||
format!(
|
||
"(user_id: {}) from hotel {}",
|
||
user_id, hotel_id
|
||
)
|
||
}
|
||
|
||
pub struct AuthUser(pub Claims); //??
|
||
|
||
#[derive(Debug, Clone)]
|
||
pub struct AuthClaims {
|
||
pub user_id: i32,
|
||
pub hotel_id: i32,
|
||
//pub username: String,
|
||
}
|
||
|
||
impl<S> FromRequestParts<S> for AuthClaims
|
||
where
|
||
S: Send + Sync + 'static,
|
||
AppState: Clone + Send + Sync + 'static, AppState: FromRef<S>
|
||
{
|
||
type Rejection = (StatusCode, String);
|
||
|
||
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
|
||
// We assume your state has a `jwt_secret` field
|
||
let Extension(keys): Extension<JwtKeys> =
|
||
Extension::from_request_parts(parts, state).await.map_err(|_| (StatusCode::UNAUTHORIZED, "Missing keys".to_string()))?;
|
||
|
||
// 1️⃣ Extract the token from the Authorization header
|
||
let auth_header = parts
|
||
.headers
|
||
.get("Authorization")
|
||
.ok_or((StatusCode::UNAUTHORIZED, "Missing Authorization header".to_string()))?
|
||
.to_str()
|
||
.map_err(|_| (StatusCode::BAD_REQUEST, "Invalid Authorization header".to_string()))?;
|
||
|
||
// Bearer token?
|
||
let token = auth_header
|
||
.strip_prefix("Bearer ")
|
||
.ok_or((StatusCode::BAD_REQUEST, "Expected Bearer token".to_string()))?;
|
||
|
||
// 2️⃣ Decode the token
|
||
let token_data = decode::<Claims>(
|
||
token,
|
||
&keys.decoding,
|
||
&Validation::new(Algorithm::HS256),
|
||
).map_err(|_| (StatusCode::UNAUTHORIZED, "Invalid token".to_string()))?;
|
||
|
||
Ok(AuthClaims {
|
||
user_id: token_data.claims.id,
|
||
hotel_id: token_data.claims.hotel_id,
|
||
//username: token_data.claims.username,
|
||
})
|
||
|
||
}
|
||
}
|
||
|
||
// Hash a new password
|
||
fn hash_password(password: &str) -> anyhow::Result<String> {
|
||
let salt = SaltString::generate(&mut OsRng); // unique per password
|
||
let argon2 = Argon2::default(); // Argon2id with good defaults
|
||
|
||
let password_hash = argon2
|
||
.hash_password(password.as_bytes(), &salt)
|
||
.map_err(|e| anyhow::anyhow!(e))?
|
||
.to_string();
|
||
|
||
Ok(password_hash)
|
||
}
|
||
|
||
// Verify an incoming password against stored hash
|
||
fn verify_password(password: &str, stored_hash: &str) -> bool {
|
||
|
||
let parsed_hash = match PasswordHash::new(&stored_hash) {
|
||
Ok(hash) => hash,
|
||
Err(_) => return false,
|
||
};
|
||
|
||
Argon2::default()
|
||
.verify_password(password.as_bytes(), &parsed_hash).is_ok()
|
||
|
||
}
|
||
|
||
#[derive(Deserialize, Debug)]
|
||
pub struct RegisterValues{
|
||
username: String,
|
||
password: String,
|
||
hotel_id: i32,
|
||
displayname: String,
|
||
}
|
||
|
||
pub struct RegisterPayload(pub RegisterValues);
|
||
|
||
impl<S> FromRequest<S> for RegisterPayload
|
||
where S: Send + Sync,
|
||
{
|
||
type Rejection = (StatusCode, String);
|
||
|
||
async fn from_request(req: ExtractRequest, state: &S) -> Result<Self, Self::Rejection> {
|
||
let Json(payload) = Json::<RegisterValues>::from_request(req, state)
|
||
.await
|
||
.map_err(|err| (StatusCode::BAD_REQUEST, format!("Invalid body: {}", err)))?;
|
||
Ok(RegisterPayload(payload))
|
||
}
|
||
}
|
||
|
||
pub async fn register_user (
|
||
State(state): State<AppState>,
|
||
RegisterPayload(payload): RegisterPayload
|
||
) -> Result<impl IntoResponse, (StatusCode, &'static str)> {
|
||
|
||
let hashed_password = hash_password(&payload.password)
|
||
.map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "Password hashing failed"))?;
|
||
|
||
let conn = state.logs_pool.get()
|
||
.map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "DB connection error"))?;
|
||
|
||
conn.execute(
|
||
"INSERT INTO users (username, password, hotel_id, displayname) VALUES (?1, ?2, ?3, ?4)",
|
||
params![payload.username, hashed_password, payload.hotel_id, payload.displayname],
|
||
)
|
||
.map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "DB insert error"))?;
|
||
|
||
Ok((StatusCode::CREATED, "User registered successfully"))
|
||
}
|
||
|
||
#[derive(Serialize, Deserialize, Debug)]
|
||
pub struct ForceUpdatePasswordValues{
|
||
username: String,
|
||
newpassword: String,
|
||
hotel_id: i32,
|
||
admin_pass: String,
|
||
}
|
||
|
||
//pub struct ForceUpdatePasswordPayload (pub ForceUpdatePasswordValues);
|
||
|
||
pub async fn ForceUpdatePassword(
|
||
State(state): State<AppState>,
|
||
Json(payload): Json<ForceUpdatePasswordValues>,
|
||
) -> impl IntoResponse {
|
||
|
||
let conn = match state.logs_pool.get() {
|
||
Ok(c) => c,
|
||
Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, "DB, conn failed").into_response()
|
||
};
|
||
|
||
let user_row = match conn.query_row(
|
||
"SELECT id FROM users WHERE username = ?1 AND hotel_id = ?2",
|
||
params![&payload.username, &payload.hotel_id],
|
||
|row|{
|
||
let user_id: i32 = row.get(0)?;
|
||
//let hotel_id: i32 = row.get(1)?;
|
||
Ok((user_id))
|
||
},
|
||
).optional() {
|
||
Ok(opt) => opt,
|
||
Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, "DB query error")
|
||
.into_response(),
|
||
};
|
||
|
||
let (user_id) = match user_row {
|
||
Some(u) => u,
|
||
None => return (StatusCode::UNAUTHORIZED, "Not correct user")
|
||
.into_response(),
|
||
};
|
||
|
||
let admin_check: String = "my_admin_password".to_string();
|
||
|
||
if &payload.admin_pass != &admin_check {
|
||
return (StatusCode::UNAUTHORIZED, "Invalid Amin Password").into_response()
|
||
};
|
||
|
||
|
||
let hashed_password = match hash_password(&payload.newpassword) {
|
||
Ok(h) => h,
|
||
Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, "Password hashing failed").into_response(),
|
||
};
|
||
|
||
let result = conn.execute(
|
||
"UPDATE users SET password = ?1 WHERE id = ?2",
|
||
params![&hashed_password, &user_id],
|
||
);
|
||
|
||
match result {
|
||
Ok(rows) if rows > 0 => (StatusCode::OK, "Password updated").into_response(),
|
||
Ok(_) => (StatusCode::NOT_FOUND, "User not found").into_response(),
|
||
Err(_) => (StatusCode::INTERNAL_SERVER_ERROR, "Failed to update password").into_response(),
|
||
}
|
||
|
||
}
|
||
|
||
#[derive(Serialize, Deserialize, Debug)]
|
||
pub struct UpdatePasswordValues{
|
||
username: String,
|
||
current_password: String,
|
||
newpassword: String,
|
||
hotel_id: i32,
|
||
|
||
}
|
||
|
||
pub async fn UpdatePassword(
|
||
State(state): State<AppState>,
|
||
Json(payload): Json<UpdatePasswordValues>,
|
||
) -> impl IntoResponse {
|
||
|
||
let conn = match state.logs_pool.get() {
|
||
Ok(c) => c,
|
||
Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, "DB, conn failed").into_response()
|
||
};
|
||
|
||
let user_row = match conn.query_row(
|
||
"SELECT password, id FROM users WHERE username = ?1 AND current_password = ?2",
|
||
params![&payload.username, &payload.current_password],
|
||
|row|{
|
||
let password: String = row.get(0)?;
|
||
let id: i32 = row.get(1)?;
|
||
Ok((password, id))
|
||
},
|
||
).optional() {
|
||
Ok(opt) => opt,
|
||
Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, "DB query error")
|
||
.into_response(),
|
||
};
|
||
|
||
let (password, user_id) = match user_row {
|
||
Some(u) => u,
|
||
None => return (StatusCode::UNAUTHORIZED, "Not correct user")
|
||
.into_response(),
|
||
};
|
||
|
||
if verify_password( &payload.current_password, &password ) {
|
||
return (StatusCode::UNAUTHORIZED, "Invalid Password").into_response()
|
||
};
|
||
|
||
|
||
let hashed_password = match hash_password(&payload.newpassword) {
|
||
Ok(h) => h,
|
||
Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, "Password hashing failed").into_response(),
|
||
};
|
||
|
||
let result = conn.execute(
|
||
"UPDATE users SET password = ?1 WHERE id = ?2",
|
||
params![&hashed_password, &user_id],
|
||
);
|
||
|
||
match result {
|
||
Ok(rows) if rows > 0 => (StatusCode::OK, "Password updated").into_response(),
|
||
Ok(_) => (StatusCode::NOT_FOUND, "User not found").into_response(),
|
||
Err(_) => (StatusCode::INTERNAL_SERVER_ERROR, "Failed to update password").into_response(),
|
||
}
|
||
|
||
}
|
||
|
||
#[derive(Deserialize, Debug)]
|
||
pub struct LoginValues {
|
||
username : String,
|
||
password : String,
|
||
hotel_id: i32,
|
||
}
|
||
|
||
pub struct LoginPayload(pub LoginValues);
|
||
|
||
impl<S> FromRequest<S> for LoginPayload
|
||
where S: Send + Sync,
|
||
{
|
||
type Rejection = (StatusCode, String);
|
||
|
||
async fn from_request(req: ExtractRequest, state: &S) -> Result<Self, Self::Rejection> {
|
||
let Json(payload) = Json::<LoginValues>::from_request(req, state)
|
||
.await
|
||
.map_err(|err| (StatusCode::BAD_REQUEST, format!("Invalid body: {}", err)))?;
|
||
|
||
Ok(LoginPayload(payload))
|
||
}
|
||
}
|
||
|
||
#[derive(Deserialize,Debug, Serialize, Clone)]
|
||
struct Claims{
|
||
id: i32,
|
||
hotel_id: i32,
|
||
//display_name
|
||
username: String,
|
||
exp: usize,
|
||
}
|
||
|
||
#[derive(Serialize)]
|
||
struct LoginResponse {
|
||
token: String,
|
||
}
|
||
|
||
pub async fn clean_auth_loging(
|
||
State(state): State<AppState>,
|
||
Extension(keys): Extension<JwtKeys>,
|
||
LoginPayload(payload): LoginPayload,
|
||
) -> impl IntoResponse {
|
||
// 1️⃣ Get a connection from logs pool
|
||
let conn = match state.logs_pool.get() {
|
||
Ok(c) => c,
|
||
Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, "DB connection error").into_response(),
|
||
};
|
||
|
||
let user_row = match conn.query_row(
|
||
"SELECT id, password, hotel_id, displayname FROM users WHERE username = ?1",
|
||
params![&payload.username],
|
||
|row| {
|
||
let user_id: i32 = row.get(0)?;
|
||
let password: String = row.get(1)?;
|
||
let hotel_id: i32 = row.get(2)?;
|
||
let displayname: String = row.get(3)?;
|
||
Ok((user_id, password, hotel_id, displayname))
|
||
},
|
||
).optional() {
|
||
Ok(opt) => opt,
|
||
Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, "DB query error").into_response(),
|
||
};
|
||
|
||
let (user_id, stored_hash, hotel_id, displayname) = match user_row {
|
||
Some(u) => u,
|
||
None => return (StatusCode::UNAUTHORIZED, "Invalid credentials").into_response(),
|
||
};
|
||
|
||
if !verify_password(&payload.password, &stored_hash) {
|
||
return (StatusCode::UNAUTHORIZED, "Invelid credentials").into_response();
|
||
}
|
||
|
||
|
||
let expiration = match chrono::Utc::now().checked_add_signed(chrono::Duration::hours(15)) {
|
||
Some(time) => time.timestamp() as usize,
|
||
None => {
|
||
// Handle overflow — probably a 500, since this should never happen
|
||
return (StatusCode::INTERNAL_SERVER_ERROR, "Time overflow".to_string()).into_response();
|
||
}
|
||
};
|
||
|
||
let claims = serde_json::json!({
|
||
"id": user_id,
|
||
"hotel_id": hotel_id,
|
||
"username": payload.username,
|
||
"exp": expiration
|
||
});
|
||
|
||
let token = match encode(
|
||
&Header::default(),
|
||
&claims,
|
||
&keys.encoding
|
||
) {
|
||
Ok(t) => t,
|
||
Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, "JWT creation failed").into_response(),
|
||
};
|
||
Json(LoginResponse { token }).into_response()
|
||
}
|
||
|
||
#[derive(Deserialize, Debug)]
|
||
pub struct CreateRefreshTokenValue {
|
||
pub username: String,
|
||
pub password: String,
|
||
pub device_id: Uuid,
|
||
//pub timestamp: Option<String>,
|
||
|
||
}
|
||
|
||
#[axum::debug_handler]
|
||
pub async fn create_refresh_token(
|
||
State(state): State<AppState>,
|
||
user_agent: Option<TypedHeader<UserAgent>>,
|
||
Json(payload): Json<CreateRefreshTokenValue>
|
||
) -> Result<impl IntoResponse, (StatusCode, String)> { // ← Add Result here
|
||
|
||
let user_agent_str = user_agent
|
||
.map(|ua| ua.to_string())
|
||
.unwrap_or_else(|| "Unknown".to_string());
|
||
|
||
let device_id_str = payload.device_id.to_string();
|
||
|
||
let argon2 = Argon2::default();
|
||
let salt = SaltString::generate(&mut OsRng);
|
||
let mut bytes = [0u8; 64];
|
||
OsRng.fill_bytes(&mut bytes);
|
||
|
||
let raw_token = Uuid::new_v4().to_string();
|
||
|
||
let hashed_token = argon2
|
||
.hash_password(raw_token.as_bytes(), &salt)
|
||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?
|
||
.to_string();
|
||
|
||
let conn = state.logs_pool.get()
|
||
.map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "DB connection error".to_string()))?;
|
||
|
||
let user_row = conn.query_row(
|
||
"SELECT id, password, hotel_id FROM users WHERE username = ?1",
|
||
params![&payload.username],
|
||
|row| {
|
||
let user_id: i32 = row.get(0)?;
|
||
let password: String = row.get(1)?;
|
||
let hotel_id: i32 = row.get(2)?;
|
||
Ok((user_id, password, hotel_id))
|
||
},
|
||
).optional()
|
||
.map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "DB get user id error".to_string()))?;
|
||
|
||
let (user_id, stored_hash, hotel_id) = user_row
|
||
.ok_or((StatusCode::NOT_FOUND, "User not found".to_string()))?;
|
||
|
||
if !verify_password(&payload.password, &stored_hash) {
|
||
return Err((StatusCode::UNAUTHORIZED, "Invalid credentials".to_string()));
|
||
}
|
||
|
||
conn.execute(
|
||
"INSERT INTO refresh_token (user_id, token_hash, device_id, user_agent, hotel_id) VALUES (?1, ?2, ?3, ?4, ?5)",
|
||
params![user_id, hashed_token, device_id_str, user_agent_str, hotel_id],
|
||
)
|
||
.map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "DB insert error".to_string()))?;
|
||
|
||
let cookie_value = format!("refresh_token={}; HttpOnly; Secure; Path=/", raw_token);
|
||
|
||
let mut response = (StatusCode::CREATED, "Refresh token created successfully").into_response();
|
||
response.headers_mut().insert(
|
||
SET_COOKIE,
|
||
HeaderValue::from_str(&cookie_value).unwrap(),
|
||
);
|
||
|
||
Ok(response) // ← Wrap in Ok()
|
||
}
|
||
|
||
#[derive(Deserialize)]
|
||
pub struct LoginRefreshTokenValues{
|
||
device_id: Uuid,
|
||
refresh_token: String,
|
||
}
|
||
|
||
pub async fn login_refresh_token (
|
||
State(state): State<AppState>,
|
||
Extension(keys): Extension<JwtKeys>,
|
||
user_agent: Option<TypedHeader<UserAgent>>,
|
||
Json(payload): Json<LoginRefreshTokenValues>
|
||
) -> impl IntoResponse {
|
||
|
||
let conn = match state.logs_pool.get() {
|
||
Ok(c) => c,
|
||
Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, "DB connection error").into_response(),
|
||
};
|
||
|
||
let user_agent_str = user_agent
|
||
.map(|ua| ua.to_string())
|
||
.unwrap_or_else(|| "Unknown".to_string());
|
||
|
||
let device_id_str = payload.device_id.to_string();
|
||
|
||
//"SELECT user_id, token_hash, hotel_id FROM refresh_token WHERE device_id = ?1 AND user_agent = ?2",
|
||
|
||
let device_row = match conn.query_row(
|
||
"SELECT user_id, token_hash, hotel_id FROM refresh_token WHERE device_id = ?1 AND user_agent = ?2",
|
||
params![&device_id_str, &user_agent_str],
|
||
|row| {
|
||
let user_id: i32 = row.get(0)?;
|
||
let token_hash: String = row.get(1)?;
|
||
let hotel_id: i32 = row.get(2)?;
|
||
//let displayname: String = row.get(3)?;
|
||
Ok((user_id, token_hash, hotel_id))
|
||
},
|
||
).optional() {
|
||
Ok(opt) => opt,
|
||
Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, "DB query error").into_response(),
|
||
};
|
||
|
||
|
||
let (user_id, token_hash, hotel_id) = match device_row {
|
||
Some(tuple) => tuple,
|
||
None => return (StatusCode::UNAUTHORIZED, "No matching device").into_response(),
|
||
};
|
||
|
||
|
||
if !verify_password(&payload.refresh_token, &token_hash) {
|
||
return (StatusCode::UNAUTHORIZED, "Invalid or mismatched token").into_response();
|
||
}
|
||
|
||
|
||
let expiration = match chrono::Utc::now().checked_add_signed(chrono::Duration::hours(15)) {
|
||
Some(time) => time.timestamp() as usize,
|
||
None => {
|
||
// Handle overflow — probably a 500, since this should never happen
|
||
return (StatusCode::INTERNAL_SERVER_ERROR, "Time overflow".to_string()).into_response();
|
||
}
|
||
};
|
||
|
||
let claims = serde_json::json!({
|
||
"id": user_id,
|
||
"hotel_id": hotel_id,
|
||
//"username": payload.username,
|
||
"exp": expiration
|
||
});
|
||
|
||
let token = match encode(
|
||
&Header::default(),
|
||
&claims,
|
||
&keys.encoding
|
||
) {
|
||
Ok(t) => t,
|
||
Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, "JWT creation failed").into_response(),
|
||
};
|
||
Json(LoginResponse { token }).into_response()
|
||
}
|
||
|
||
fn internal_error<E: std::fmt::Display>(err: E) -> (StatusCode, String) {
|
||
(StatusCode::INTERNAL_SERVER_ERROR, format!("Internal error: {}", err))
|
||
} |