1181 lines
34 KiB
Rust
1181 lines
34 KiB
Rust
use axum::{
|
||
Json,
|
||
body::{Body, to_bytes},
|
||
extract::{Extension, FromRequest, FromRequestParts, Path, State, ws::close_code::STATUS},
|
||
http::{
|
||
Request as HttpRequest, StatusCode,
|
||
header::{HeaderValue, SET_COOKIE},
|
||
request::Parts,
|
||
status,
|
||
},
|
||
middleware::Next,
|
||
response::{IntoResponse, IntoResponseParts, Response},
|
||
};
|
||
use std::time::Duration;
|
||
|
||
use axum_extra::extract::TypedHeader;
|
||
//use axum_extra::TypedHeader;
|
||
|
||
use futures_util::future::TrySelect;
|
||
use headers::{Cookie, UserAgent};
|
||
|
||
use axum::extract::FromRef;
|
||
use axum::extract::Request as ExtractRequest;
|
||
use chrono::{Utc, format};
|
||
use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation, decode, encode};
|
||
use reqwest::header::REFRESH;
|
||
use rusqlite::{Connection, OptionalExtension, params};
|
||
use serde::{Deserialize, Serialize};
|
||
use serde_json::Value;
|
||
|
||
use argon2::{
|
||
Argon2,
|
||
password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString},
|
||
};
|
||
use rand_core::{OsRng, RngCore};
|
||
|
||
use uuid::Uuid;
|
||
|
||
//use crate::utils::db_pool::;
|
||
use crate::utils::db_pool::{AppState, HotelPool};
|
||
|
||
use base64::{Engine as _, engine::general_purpose};
|
||
|
||
#[derive(Clone)]
|
||
pub struct JwtKeys {
|
||
pub encoding: EncodingKey,
|
||
pub decoding: DecodingKey,
|
||
}
|
||
|
||
pub async fn token_tester(
|
||
State(state): State<AppState>,
|
||
//Extension(keys): Extension<JwtKeys>,
|
||
AuthClaims { user_id, hotel_id }: AuthClaims,
|
||
) -> impl IntoResponse {
|
||
format!("(user_id: {}) from hotel {}", user_id, hotel_id)
|
||
}
|
||
|
||
/// Newtype around decoded JWT [`Claims`].
///
/// NOTE(review): not referenced elsewhere in this module — confirm it is still needed.
pub struct AuthUser(pub Claims);
|
||
|
||
/// Caller identity extracted from a validated access token.
///
/// Built by [`auth_claims_from_token`]; usable directly as a handler argument
/// thanks to its `FromRequestParts` implementation.
#[derive(Clone, Debug)]
pub struct AuthClaims {
    /// `users.id` of the token holder (the JWT `id` claim).
    pub user_id: i32,
    /// Hotel the token was issued for (the JWT `hotel_id` claim).
    pub hotel_id: i32,
}
|
||
|
||
pub fn auth_claims_from_token(
|
||
token: &str,
|
||
keys: &JwtKeys,
|
||
) -> Result<AuthClaims, (StatusCode, String)> {
|
||
let token_data = decode::<Claims>(token, &keys.decoding, &Validation::new(Algorithm::HS256))
|
||
.map_err(|_| (StatusCode::UNAUTHORIZED, "Invalid token".into()))?;
|
||
|
||
Ok(AuthClaims {
|
||
user_id: token_data.claims.id,
|
||
hotel_id: token_data.claims.hotel_id,
|
||
})
|
||
}
|
||
|
||
impl<S> FromRequestParts<S> for AuthClaims
|
||
where
|
||
S: Send + Sync + 'static,
|
||
AppState: FromRef<S>,
|
||
{
|
||
type Rejection = (StatusCode, String);
|
||
|
||
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
|
||
let Extension(keys): Extension<JwtKeys> = Extension::from_request_parts(parts, state)
|
||
.await
|
||
.map_err(|_| (StatusCode::UNAUTHORIZED, "Missing keys".into()))?;
|
||
|
||
let auth_header = parts
|
||
.headers
|
||
.get(axum::http::header::AUTHORIZATION)
|
||
.ok_or((
|
||
StatusCode::UNAUTHORIZED,
|
||
"Missing Authorization header".into(),
|
||
))?
|
||
.to_str()
|
||
.map_err(|_| {
|
||
(
|
||
StatusCode::BAD_REQUEST,
|
||
"Invalid Authorization header".into(),
|
||
)
|
||
})?;
|
||
|
||
let token = auth_header
|
||
.strip_prefix("Bearer ")
|
||
.ok_or((StatusCode::BAD_REQUEST, "Expected Bearer token".into()))?;
|
||
|
||
auth_claims_from_token(token, &keys)
|
||
}
|
||
}
|
||
|
||
// Hash a new password
|
||
fn hash_password(password: &str) -> anyhow::Result<String> {
|
||
let salt = SaltString::generate(&mut OsRng); // unique per password
|
||
let argon2 = Argon2::default(); // Argon2id with good defaults
|
||
|
||
let password_hash = argon2
|
||
.hash_password(password.as_bytes(), &salt)
|
||
.map_err(|e| anyhow::anyhow!(e))?
|
||
.to_string();
|
||
|
||
Ok(password_hash)
|
||
}
|
||
|
||
// Verify an incoming password against stored hash
|
||
fn verify_password(password: &str, stored_hash: &str) -> bool {
|
||
let parsed_hash = match PasswordHash::new(&stored_hash) {
|
||
Ok(hash) => hash,
|
||
Err(_) => return false,
|
||
};
|
||
|
||
Argon2::default()
|
||
.verify_password(password.as_bytes(), &parsed_hash)
|
||
.is_ok()
|
||
}
|
||
|
||
#[derive(Deserialize, Debug)]
|
||
pub struct RegisterValues {
|
||
username: String,
|
||
password: String,
|
||
#[serde(default)]
|
||
hotel_ids: Vec<i32>, //-> :Vec!<32>, maybe optionnal ?
|
||
displayname: String,
|
||
}
|
||
|
||
/// Extractor wrapping a [`RegisterValues`] JSON body; malformed bodies are
/// rejected with `400 Bad Request`.
pub struct RegisterPayload(pub RegisterValues);
|
||
|
||
impl<S> FromRequest<S> for RegisterPayload
|
||
where
|
||
S: Send + Sync,
|
||
{
|
||
type Rejection = (StatusCode, String);
|
||
|
||
async fn from_request(req: ExtractRequest, state: &S) -> Result<Self, Self::Rejection> {
|
||
let Json(payload) = Json::<RegisterValues>::from_request(req, state)
|
||
.await
|
||
.map_err(|err| (StatusCode::BAD_REQUEST, format!("Invalid body: {}", err)))?;
|
||
Ok(RegisterPayload(payload))
|
||
}
|
||
}
|
||
|
||
//TODO: Validate all hotel_ids first + Use a transaction + Batch query hotel names with IN (...)
|
||
|
||
pub async fn register_user(
|
||
State(state): State<AppState>,
|
||
RegisterPayload(payload): RegisterPayload,
|
||
) -> Result<impl IntoResponse, (StatusCode, String)> {
|
||
let hashed_password = hash_password(&payload.password).map_err(|e| {
|
||
(
|
||
StatusCode::INTERNAL_SERVER_ERROR,
|
||
format!("Password hashing failed: {}", e),
|
||
)
|
||
})?;
|
||
|
||
let conn = state.logs_pool.get().map_err(|e| {
|
||
(
|
||
StatusCode::INTERNAL_SERVER_ERROR,
|
||
format!("DB connection error: {}", e),
|
||
)
|
||
})?;
|
||
|
||
conn.execute(
|
||
"INSERT INTO users (username, password, displayname)
|
||
VALUES (?1, ?2, ?3)",
|
||
params![payload.username, hashed_password, payload.displayname],
|
||
)
|
||
.map_err(|e| {
|
||
(
|
||
StatusCode::INTERNAL_SERVER_ERROR,
|
||
format!("User insert error: {}", e),
|
||
)
|
||
})?;
|
||
|
||
let user_id = conn.last_insert_rowid();
|
||
|
||
for &hotel_id in &payload.hotel_ids {
|
||
// more logic for security here
|
||
//FIXME: needs to be the display name in the DB, scheme is currently wrong
|
||
|
||
let hotel_name: String = conn
|
||
.query_row(
|
||
"SELECT hotelname FROM hotels
|
||
WHERE id = ?1 ",
|
||
params![hotel_id],
|
||
|row| row.get(0),
|
||
)
|
||
.map_err(|e| {
|
||
(
|
||
StatusCode::BAD_REQUEST,
|
||
format!("Invalid hotel id {}: {}", hotel_id, e),
|
||
)
|
||
})?;
|
||
|
||
conn.execute(
|
||
"INSERT INTO hotel_user_link (user_id, hotel_id, username, hotelname)
|
||
VALUES (?1, ?2, ?3, ?4)",
|
||
params![user_id, hotel_id, payload.username, hotel_name],
|
||
)
|
||
.map_err(|e| {
|
||
(
|
||
StatusCode::INTERNAL_SERVER_ERROR,
|
||
format!(
|
||
"Link insert error for user_id={} hotel_id={}: {}",
|
||
user_id, hotel_id, e
|
||
),
|
||
)
|
||
})?;
|
||
}
|
||
|
||
Ok((
|
||
StatusCode::CREATED,
|
||
"User registered successfully".to_string(),
|
||
))
|
||
}
|
||
|
||
#[derive(Serialize, Deserialize, Debug)]
|
||
pub struct ForceUpdatePasswordValues {
|
||
username: String,
|
||
newpassword: String,
|
||
hotel_id: i32,
|
||
admin_pass: String,
|
||
}
|
||
|
||
//pub struct ForceUpdatePasswordPayload (pub ForceUpdatePasswordValues);
|
||
|
||
pub async fn force_update_password(
|
||
State(state): State<AppState>,
|
||
Json(payload): Json<ForceUpdatePasswordValues>,
|
||
) -> impl IntoResponse {
|
||
let conn = match state.logs_pool.get() {
|
||
Ok(c) => c,
|
||
Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, "DB, conn failed").into_response(),
|
||
};
|
||
|
||
let user_row = match conn
|
||
.query_row(
|
||
"SELECT id FROM users WHERE username = ?1 AND hotel_id = ?2",
|
||
params![&payload.username, &payload.hotel_id],
|
||
|row| {
|
||
let user_id: i32 = row.get(0)?;
|
||
//let hotel_id: i32 = row.get(1)?;
|
||
Ok(user_id)
|
||
},
|
||
)
|
||
.optional()
|
||
{
|
||
Ok(opt) => opt,
|
||
Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, "DB query error").into_response(),
|
||
};
|
||
|
||
let user_id = match user_row {
|
||
Some(u) => u,
|
||
None => return (StatusCode::UNAUTHORIZED, "Not correct user").into_response(),
|
||
};
|
||
|
||
let admin_check: String = "my_admin_password".to_string();
|
||
|
||
if &payload.admin_pass != &admin_check {
|
||
return (StatusCode::UNAUTHORIZED, "Invalid Amin Password").into_response();
|
||
};
|
||
|
||
let hashed_password = match hash_password(&payload.newpassword) {
|
||
Ok(h) => h,
|
||
Err(_) => {
|
||
return (StatusCode::INTERNAL_SERVER_ERROR, "Password hashing failed").into_response();
|
||
}
|
||
};
|
||
|
||
let result = conn.execute(
|
||
"UPDATE users SET password = ?1 WHERE id = ?2",
|
||
params![&hashed_password, &user_id],
|
||
);
|
||
|
||
match result {
|
||
Ok(rows) if rows > 0 => (StatusCode::OK, "Password updated").into_response(),
|
||
Ok(_) => (StatusCode::NOT_FOUND, "User not found").into_response(),
|
||
Err(_) => (
|
||
StatusCode::INTERNAL_SERVER_ERROR,
|
||
"Failed to update password",
|
||
)
|
||
.into_response(),
|
||
}
|
||
}
|
||
|
||
#[derive(Serialize, Deserialize, Debug)]
|
||
pub struct UpdatePasswordValues {
|
||
username: String,
|
||
current_password: String,
|
||
newpassword: String,
|
||
//hotel_id: i32,
|
||
}
|
||
|
||
pub async fn update_password(
|
||
State(state): State<AppState>,
|
||
Json(payload): Json<UpdatePasswordValues>,
|
||
) -> impl IntoResponse {
|
||
let conn = match state.logs_pool.get() {
|
||
Ok(c) => c,
|
||
Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, "DB, conn failed").into_response(),
|
||
};
|
||
|
||
let user_row = match conn
|
||
.query_row(
|
||
"SELECT password, id FROM users WHERE username = ?1",
|
||
params![&payload.username],
|
||
|row| {
|
||
let password: String = row.get(0)?;
|
||
let id: i32 = row.get(1)?;
|
||
Ok((password, id))
|
||
},
|
||
)
|
||
.optional()
|
||
{
|
||
Ok(opt) => opt,
|
||
Err(e) => {
|
||
return (
|
||
StatusCode::INTERNAL_SERVER_ERROR,
|
||
format!("DB query error: {}", e),
|
||
)
|
||
.into_response();
|
||
}
|
||
};
|
||
|
||
let (password, user_id) = match user_row {
|
||
Some(u) => u,
|
||
None => return (StatusCode::UNAUTHORIZED, "Not correct user").into_response(),
|
||
};
|
||
|
||
if !verify_password(&payload.current_password, &password) {
|
||
return (StatusCode::UNAUTHORIZED, "Invalid Password").into_response();
|
||
};
|
||
|
||
let hashed_password = match hash_password(&payload.newpassword) {
|
||
Ok(h) => h,
|
||
Err(_) => {
|
||
return (StatusCode::INTERNAL_SERVER_ERROR, "Password hashing failed").into_response();
|
||
}
|
||
};
|
||
|
||
let result = conn.execute(
|
||
"UPDATE users SET password = ?1 WHERE id = ?2",
|
||
params![&hashed_password, &user_id],
|
||
);
|
||
|
||
match result {
|
||
Ok(rows) if rows > 0 => (StatusCode::OK, "Password updated").into_response(),
|
||
Ok(_) => (StatusCode::NOT_FOUND, "User not found").into_response(),
|
||
Err(_) => (
|
||
StatusCode::INTERNAL_SERVER_ERROR,
|
||
"Failed to update password",
|
||
)
|
||
.into_response(),
|
||
}
|
||
}
|
||
|
||
#[derive(Deserialize, Debug)]
|
||
pub struct LoginValues {
|
||
username: String,
|
||
password: String,
|
||
//hotel_id: i32,
|
||
}
|
||
|
||
/// Extractor wrapping a [`LoginValues`] JSON body; malformed bodies are rejected
/// with `400 Bad Request`.
pub struct LoginPayload(pub LoginValues);
|
||
|
||
impl<S> FromRequest<S> for LoginPayload
|
||
where
|
||
S: Send + Sync,
|
||
{
|
||
type Rejection = (StatusCode, String);
|
||
|
||
async fn from_request(req: ExtractRequest, state: &S) -> Result<Self, Self::Rejection> {
|
||
let Json(payload) = Json::<LoginValues>::from_request(req, state)
|
||
.await
|
||
.map_err(|err| (StatusCode::BAD_REQUEST, format!("Invalid body: {}", err)))?;
|
||
|
||
Ok(LoginPayload(payload))
|
||
}
|
||
}
|
||
|
||
#[derive(Deserialize, Debug, Serialize, Clone)]
|
||
struct Claims {
|
||
id: i32,
|
||
hotel_id: i32,
|
||
//display_name
|
||
//username: String,
|
||
exp: usize,
|
||
}
|
||
|
||
#[derive(Serialize)]
|
||
struct LoginResponse {
|
||
token: String,
|
||
}
|
||
|
||
#[derive(Serialize)]
|
||
struct MultiLoginResponse {
|
||
user_id: i32,
|
||
tokens: Vec<String>,
|
||
}
|
||
|
||
pub async fn clean_auth_loging(
|
||
State(state): State<AppState>,
|
||
Extension(keys): Extension<JwtKeys>,
|
||
LoginPayload(payload): LoginPayload,
|
||
) -> impl IntoResponse {
|
||
// 1️⃣ Get a connection from logs pool
|
||
let conn = match state.logs_pool.get() {
|
||
Ok(c) => c,
|
||
Err(_) => {
|
||
return (StatusCode::INTERNAL_SERVER_ERROR, "DB connection error").into_response();
|
||
}
|
||
};
|
||
|
||
let user_row = match conn
|
||
.query_row(
|
||
"SELECT id, password, hotel_id, displayname FROM users WHERE username = ?1",
|
||
params![&payload.username],
|
||
|row| {
|
||
let user_id: i32 = row.get(0)?;
|
||
let password: String = row.get(1)?;
|
||
let hotel_id: i32 = row.get(2)?;
|
||
let displayname: String = row.get(3)?;
|
||
Ok((user_id, password, hotel_id, displayname))
|
||
},
|
||
)
|
||
.optional()
|
||
{
|
||
Ok(opt) => opt,
|
||
Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, "DB query error").into_response(),
|
||
};
|
||
|
||
let (user_id, stored_hash, hotel_id, _displayname) = match user_row {
|
||
Some(u) => u,
|
||
None => return (StatusCode::UNAUTHORIZED, "Invalid credentials").into_response(),
|
||
};
|
||
|
||
if !verify_password(&payload.password, &stored_hash) {
|
||
return (StatusCode::UNAUTHORIZED, "Invalid credentials").into_response();
|
||
}
|
||
|
||
let expiration = match chrono::Utc::now().checked_add_signed(chrono::Duration::hours(15)) {
|
||
Some(time) => time.timestamp() as usize,
|
||
None => {
|
||
// Handle overflow — probably a 500, since this should never happen
|
||
return (
|
||
StatusCode::INTERNAL_SERVER_ERROR,
|
||
"Time overflow".to_string(),
|
||
)
|
||
.into_response();
|
||
}
|
||
};
|
||
|
||
let claims = serde_json::json!({
|
||
"id": user_id,
|
||
"hotel_id": hotel_id,
|
||
//"username": payload.username,
|
||
"exp": expiration
|
||
});
|
||
|
||
let token = match encode(&Header::default(), &claims, &keys.encoding) {
|
||
Ok(t) => t,
|
||
Err(_) => {
|
||
return (StatusCode::INTERNAL_SERVER_ERROR, "JWT creation failed").into_response();
|
||
}
|
||
};
|
||
Json(LoginResponse { token }).into_response()
|
||
}
|
||
|
||
#[derive(Deserialize, Debug)]
|
||
pub struct CreateRefreshTokenValue {
|
||
pub username: String,
|
||
pub password: String,
|
||
pub device_id: Uuid,
|
||
//pub timestamp: Option<String>,
|
||
}
|
||
|
||
//FIXME: weird return type, returning result ?
|
||
|
||
#[axum::debug_handler]
|
||
pub async fn create_refresh_token(
|
||
State(state): State<AppState>,
|
||
user_agent: Option<TypedHeader<UserAgent>>,
|
||
Json(payload): Json<CreateRefreshTokenValue>,
|
||
) -> Result<impl IntoResponse, (StatusCode, String)> {
|
||
// ← Add Result here
|
||
|
||
let user_agent_str = user_agent
|
||
.map(|ua| ua.to_string())
|
||
.unwrap_or_else(|| "Unknown".to_string());
|
||
|
||
let device_id_str = payload.device_id.to_string();
|
||
|
||
let conn = state.logs_pool.get().map_err(|_| {
|
||
(
|
||
StatusCode::INTERNAL_SERVER_ERROR,
|
||
"DB connection error".to_string(),
|
||
)
|
||
})?;
|
||
|
||
let argon2 = Argon2::default();
|
||
let salt = SaltString::generate(&mut OsRng);
|
||
let mut bytes = [0u8; 64];
|
||
OsRng.fill_bytes(&mut bytes);
|
||
let raw_token = Uuid::new_v4().to_string();
|
||
|
||
let hashed_token = &raw_token;
|
||
/*
|
||
let hashed_token = argon2
|
||
.hash_password(raw_token.as_bytes(), &salt)
|
||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?
|
||
.to_string();
|
||
*/
|
||
|
||
// let mut stmt = conn.prepare(
|
||
// "SELECT id, password FROM users WHERE username = ?1"
|
||
|
||
let credentials = match conn.query_row(
|
||
"SELECT id, password FROM users WHERE username = ?1",
|
||
params![&payload.username],
|
||
|row| {
|
||
let user_id: i32 = row.get(0)?;
|
||
let password: String = row.get(1)?;
|
||
Ok((user_id, password))
|
||
},
|
||
) {
|
||
Ok(cr) => cr,
|
||
Err(e) => {
|
||
return Err((
|
||
StatusCode::INTERNAL_SERVER_ERROR,
|
||
format!("error fetching credentials: {e}").to_string(),
|
||
));
|
||
}
|
||
};
|
||
|
||
let (user_id, user_password) = credentials;
|
||
|
||
/*
|
||
let (user_id, stored_hash, hotel_id) = user_row
|
||
.ok_or((StatusCode::NOT_FOUND, "User not found".to_string()))?;
|
||
*/
|
||
//let mut tokens = Vec::new();
|
||
//TODO: validate password
|
||
if !verify_password(&payload.password, &user_password) {
|
||
return Err((
|
||
StatusCode::INTERNAL_SERVER_ERROR,
|
||
"Invalid credential".to_string(),
|
||
)); // Skip rows with invalid password
|
||
}
|
||
|
||
//TODO: get hotel name to return a map/tuple of hotel name
|
||
let mut stmt = match conn.prepare("SELECt hotel_id FROM hotel_user_link WHERE user_id = ?1") {
|
||
Ok(stmt) => stmt,
|
||
Err(_) => {
|
||
return Err((
|
||
StatusCode::INTERNAL_SERVER_ERROR,
|
||
"error building user_id fetch stmt".to_string(),
|
||
));
|
||
}
|
||
};
|
||
|
||
//TODO: compiler les hotel id dans un vecteur pour le feed dans le refresh token
|
||
let hotel_ids: Vec<i32> = match stmt.query_map(params![&user_id], |row| row.get(0)) {
|
||
Ok(rows) => rows.collect::<Result<Vec<_>, _>>().map_err(|_| {
|
||
(
|
||
StatusCode::INTERNAL_SERVER_ERROR,
|
||
"Error collecting hotel_ids".to_string(),
|
||
)
|
||
})?,
|
||
Err(_) => {
|
||
return Err((
|
||
StatusCode::INTERNAL_SERVER_ERROR,
|
||
"Error mapping hotel_ids".to_string(),
|
||
));
|
||
}
|
||
};
|
||
|
||
let hotel_ids_json = match serde_json::to_string(&hotel_ids) {
|
||
Ok(json) => json,
|
||
Err(_) => {
|
||
return Err((
|
||
StatusCode::INTERNAL_SERVER_ERROR,
|
||
"Error mapping hotel_ids".to_string(),
|
||
));
|
||
}
|
||
};
|
||
|
||
/*.map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "Error mapping hotel_ids".to_string())); */
|
||
|
||
//FIXME: might not need the hotel list on tconflict ?
|
||
|
||
//TODO: remove user agent entirely from auth ,it is mutable and not stable
|
||
//TODO: make the token refresh on login
|
||
conn.execute(
|
||
r#"
|
||
INSERT INTO refresh_token (
|
||
user_id,
|
||
token_hash,
|
||
device_id,
|
||
hotel_id_list
|
||
)
|
||
VALUES (?1, ?2, ?3, ?4)
|
||
ON CONFLICT(user_id, device_id)
|
||
DO UPDATE SET
|
||
token_hash = excluded.token_hash,
|
||
hotel_id_list = excluded.hotel_id_list
|
||
"#,
|
||
params![user_id, hashed_token, device_id_str, hotel_ids_json],
|
||
)
|
||
.map_err(|e| {
|
||
(
|
||
StatusCode::INTERNAL_SERVER_ERROR,
|
||
format!("DB error: {}", e),
|
||
)
|
||
})?;
|
||
|
||
//TODO: add a map/tupple of of the allowed hotels and their id+name, maybe update the token ?
|
||
|
||
println!("RAW write refresh_token bytes: {:?}", &raw_token.as_bytes());
|
||
println!("RAW refresh_token : {}", &raw_token.to_string());
|
||
println!("RAW write refresh_token len: {}", &raw_token.len());
|
||
|
||
let cookie_value = format!(
|
||
"refresh_token={}; HttpOnly; Secure; Max-Age=60480000000;Path=/",
|
||
raw_token
|
||
);
|
||
|
||
let mut response = (StatusCode::CREATED, "Refresh token created successfully").into_response();
|
||
response
|
||
.headers_mut()
|
||
.insert(SET_COOKIE, HeaderValue::from_str(&cookie_value).unwrap());
|
||
|
||
Ok(response) // ← Wrap in Ok()
|
||
}
|
||
|
||
#[derive(Deserialize)]
|
||
pub struct LoginRefreshTokenValues {
|
||
device_id: Uuid,
|
||
//refresh_token: String,
|
||
}
|
||
|
||
//TODO: LATER : implement hotel-id-selected to allow user to only get part hotels ?
|
||
pub async fn login_refresh_token(
|
||
State(state): State<AppState>,
|
||
Extension(keys): Extension<JwtKeys>,
|
||
user_agent: Option<TypedHeader<UserAgent>>,
|
||
cookie_header: Option<TypedHeader<headers::Cookie>>,
|
||
Json(payload): Json<LoginRefreshTokenValues>,
|
||
) -> impl IntoResponse {
|
||
println!("login_refresh_token called");
|
||
// Log cookies
|
||
|
||
let cookies = match cookie_header {
|
||
Some(token) => token,
|
||
None => return (StatusCode::UNAUTHORIZED, "Missing refresh token cookie").into_response(),
|
||
};
|
||
|
||
let refresh_token = match cookies.get("refresh_token") {
|
||
Some(token) => token.to_string(),
|
||
None => return (StatusCode::UNAUTHORIZED, "Missing refresh token cookie").into_response(),
|
||
};
|
||
|
||
println!("RAW refresh_token bytes: {:?}", refresh_token.as_bytes());
|
||
println!("RAW refresh_token : {}", refresh_token.to_string());
|
||
println!("RAW refresh_token len: {}", refresh_token.len());
|
||
|
||
println!("Cookies: {:?}", &refresh_token);
|
||
|
||
let conn = match state.logs_pool.get() {
|
||
Ok(c) => c,
|
||
Err(_) => {
|
||
return (StatusCode::INTERNAL_SERVER_ERROR, "DB connection error").into_response();
|
||
}
|
||
};
|
||
|
||
let user_agent_str = match user_agent {
|
||
Some(ua) => ua.to_string(),
|
||
None => return (StatusCode::INTERNAL_SERVER_ERROR, "user agent unknown").into_response(),
|
||
};
|
||
println!("UA {:?}", &user_agent_str);
|
||
let device_id_str = payload.device_id.to_string();
|
||
|
||
println!("device id: {:?}", &device_id_str);
|
||
|
||
//"SELECT user_id, token_hash, hotel_id FROM refresh_token WHERE device_id = ?1 AND user_agent = ?2",
|
||
|
||
//TODO: swap to query row and get hotel-id's list and not single hotel per row
|
||
//deserializing the list :
|
||
//let hotel_ids: Vec<i32> = serde_json::from_str(&stored_value)?;
|
||
let mut stmt = match conn.prepare(
|
||
"SELECT user_id, hotel_id_list
|
||
FROM refresh_token
|
||
WHERE device_id = ?1 AND token_hash = ?2
|
||
LIMIT 1;",
|
||
) {
|
||
Ok(s) => s,
|
||
Err(_) => {
|
||
return (
|
||
StatusCode::INTERNAL_SERVER_ERROR,
|
||
"Error prepatring hotel_id_list stmt",
|
||
)
|
||
.into_response();
|
||
}
|
||
};
|
||
|
||
let rows = match stmt
|
||
.query_one(params![&device_id_str, &refresh_token], |row| {
|
||
Ok((
|
||
row.get::<_, i32>(0)?, // user_id
|
||
row.get::<_, String>(1)?, // token_hash
|
||
//row.get::<_, String>(2)?, // hotel_id //FIXME: this is supposed to be vectore maybe ?
|
||
))
|
||
})
|
||
.optional()
|
||
{
|
||
Ok(r) => r,
|
||
Err(e) => {
|
||
eprintln!("DB ERROR: {:?}", e);
|
||
return (
|
||
StatusCode::INTERNAL_SERVER_ERROR,
|
||
format!("DB query error: {}", e),
|
||
)
|
||
.into_response();
|
||
}
|
||
};
|
||
//TODO: extraction of the blob
|
||
//let json_hotel_ids = rows.2;
|
||
let (user_id, json_hotel_ids) = match rows {
|
||
Some(r) => r,
|
||
None => {
|
||
return (
|
||
StatusCode::UNAUTHORIZED,
|
||
"No refresh token found for this device",
|
||
)
|
||
.into_response();
|
||
}
|
||
};
|
||
let hotel_ids: Vec<i32> = match serde_json::from_str(&json_hotel_ids) {
|
||
Ok(ids) => ids,
|
||
Err(_) => {
|
||
return (
|
||
StatusCode::INTERNAL_SERVER_ERROR,
|
||
"Hotel ids are not deserializable to Vec",
|
||
)
|
||
.into_response();
|
||
}
|
||
};
|
||
|
||
//FIXME: still problems when corrupted token exist
|
||
if hotel_ids.is_empty() {
|
||
return (StatusCode::UNAUTHORIZED, "No matching device").into_response();
|
||
}
|
||
|
||
/*
|
||
|
||
eprintln!("DB ERROR: {:?}", &refresh_token);
|
||
eprintln!("DB ERROR: {:?}", &token);
|
||
|
||
|
||
//still not auto adding hotel user link when creating account
|
||
if (&refresh_token != &token) {
|
||
// skip rows with wrong hash
|
||
return (StatusCode::UNAUTHORIZED, "Invelid credentials").into_response();
|
||
}
|
||
*/
|
||
let expiration = match chrono::Utc::now().checked_add_signed(chrono::Duration::hours(15)) {
|
||
Some(time) => time.timestamp() as usize,
|
||
None => {
|
||
// Handle overflow — probably a 500, since this should never happen
|
||
return (
|
||
StatusCode::INTERNAL_SERVER_ERROR,
|
||
"Time overflow".to_string(),
|
||
)
|
||
.into_response();
|
||
}
|
||
};
|
||
|
||
let mut tokens = Vec::new();
|
||
|
||
for hotel_id in hotel_ids {
|
||
let claims = serde_json::json!({
|
||
"id": user_id,
|
||
"hotel_id": hotel_id,
|
||
"exp": expiration
|
||
});
|
||
|
||
let token = match encode(&Header::default(), &claims, &keys.encoding) {
|
||
Ok(token) => token,
|
||
Err(_) => {
|
||
return (StatusCode::INTERNAL_SERVER_ERROR, "JWT creation failed").into_response();
|
||
}
|
||
};
|
||
|
||
tokens.push(token);
|
||
}
|
||
|
||
if tokens.is_empty() {
|
||
return (StatusCode::UNAUTHORIZED, "Invalid or mismatched token").into_response();
|
||
}
|
||
|
||
//Json(tokens).into_response()
|
||
Json(MultiLoginResponse { user_id, tokens }).into_response()
|
||
}
|
||
|
||
#[axum::debug_handler]
|
||
pub async fn logout_from_single_device(
|
||
State(state): State<AppState>,
|
||
Extension(keys): Extension<JwtKeys>,
|
||
user_agent: Option<TypedHeader<UserAgent>>,
|
||
cookie_header: Option<TypedHeader<headers::Cookie>>,
|
||
Json(payload): Json<LoginRefreshTokenValues>,
|
||
) -> impl IntoResponse {
|
||
let user_agent_str = user_agent
|
||
.map(|TypedHeader(ua)| ua.as_str().to_owned())
|
||
.unwrap_or_else(|| "Unknown".to_string());
|
||
|
||
let cookies = match cookie_header {
|
||
Some(token) => token,
|
||
None => return (StatusCode::UNAUTHORIZED, "Missing refresh token cookie").into_response(),
|
||
};
|
||
|
||
let refresh_token = match cookies.get("refresh_token") {
|
||
Some(token) => token.to_string(),
|
||
None => return (StatusCode::UNAUTHORIZED, "Missing refresh token cookie").into_response(),
|
||
};
|
||
|
||
let device_id_str = payload.device_id.to_string();
|
||
|
||
let conn = match state.logs_pool.get() {
|
||
Ok(c) => c,
|
||
Err(_) => {
|
||
return (StatusCode::INTERNAL_SERVER_ERROR, "DB connection error").into_response();
|
||
}
|
||
};
|
||
|
||
let device_row = match conn
|
||
.query_row(
|
||
"SELECT user_id, hotel_id_list, id
|
||
FROM refresh_token
|
||
WHERE token_hash = ?1 AND revoked = 0 ",
|
||
params![&refresh_token],
|
||
|row| {
|
||
let user_id: i32 = row.get(0)?;
|
||
let json_hotel_id_list: String = row.get(1)?;
|
||
let id: i32 = row.get(2)?;
|
||
//let displayname: String = row.get(3)?;
|
||
Ok((user_id, json_hotel_id_list, id))
|
||
},
|
||
)
|
||
.optional()
|
||
{
|
||
Ok(opt) => opt,
|
||
Err(e) => {
|
||
return (
|
||
StatusCode::INTERNAL_SERVER_ERROR,
|
||
format!("DB query error : {}", e),
|
||
)
|
||
.into_response();
|
||
}
|
||
};
|
||
|
||
let (user_id, json_hotel_id_list, token_id) = match device_row {
|
||
Some(tuple) => tuple,
|
||
None => return (StatusCode::UNAUTHORIZED, "No matching device").into_response(),
|
||
};
|
||
|
||
let hotel_ids: Vec<i32> = match serde_json::from_str(&json_hotel_id_list) {
|
||
Ok(ids) => ids,
|
||
Err(_) => {
|
||
return (
|
||
StatusCode::INTERNAL_SERVER_ERROR,
|
||
"Hotel ids are not deserializable to Vec",
|
||
)
|
||
.into_response();
|
||
}
|
||
};
|
||
|
||
//FIXME: need to chang the way we get refresh token from the cookies instead
|
||
/*
|
||
if !verify_password(&payload.refresh_token, &token_hash) {
|
||
return (StatusCode::UNAUTHORIZED, "Invalid or mismatched token").into_response();
|
||
}
|
||
*/
|
||
|
||
let revoked: Result<String, rusqlite::Error> = conn.query_row(
|
||
"DELETE FROM refresh_token
|
||
WHERE id = ?1
|
||
RETURNING device_id",
|
||
params![&token_id],
|
||
|row| row.get(0),
|
||
);
|
||
|
||
let revoked_id = match (revoked) {
|
||
Ok(r) => r,
|
||
Err(e) => {
|
||
return (
|
||
StatusCode::INTERNAL_SERVER_ERROR,
|
||
"Hotel ids are not deserializable to Vec",
|
||
)
|
||
.into_response();
|
||
}
|
||
};
|
||
|
||
let cookie_value = format!(
|
||
"refresh_token={}; HttpOnly; Secure; Max-Age=0;Path=/",
|
||
"loggedout"
|
||
);
|
||
|
||
let mut response = (
|
||
StatusCode::CREATED,
|
||
format!("Token deleted for device id {}", &revoked_id),
|
||
)
|
||
.into_response();
|
||
|
||
response
|
||
.headers_mut()
|
||
.insert(SET_COOKIE, HeaderValue::from_str(&cookie_value).unwrap());
|
||
|
||
response
|
||
}
|
||
|
||
pub async fn logout_from_all_devices(
|
||
State(state): State<AppState>,
|
||
Extension(keys): Extension<JwtKeys>,
|
||
AuthClaims { user_id, hotel_id }: AuthClaims,
|
||
//Json(payload): Json<LoginRefreshTokenValues>
|
||
) -> impl IntoResponse {
|
||
//let device_id_str = payload.device_id.to_string();
|
||
|
||
let conn = match state.logs_pool.get() {
|
||
Ok(c) => c,
|
||
Err(_) => {
|
||
return (StatusCode::INTERNAL_SERVER_ERROR, "DB connection error").into_response();
|
||
}
|
||
};
|
||
|
||
let result = conn.execute(
|
||
"DELETE FROM refresh_token WHERE user_id = ?1",
|
||
params![&user_id],
|
||
);
|
||
|
||
/*
|
||
|
||
match result {
|
||
//Ok(count) if count > 0 => {
|
||
// (StatusCode::OK, format!("Revoked {} active tokens", count)).into_response()
|
||
//}
|
||
//Ok(_) => (StatusCode::NOT_FOUND, "No active tokens to revoke").into_response(),
|
||
Err(_) => (
|
||
StatusCode::INTERNAL_SERVER_ERROR,
|
||
"Database update error".to_string(),
|
||
)
|
||
.into_response(),
|
||
}
|
||
*/
|
||
|
||
let cookie_value = format!(
|
||
"refresh_token={}; HttpOnly; Secure; Max-Age=0;Path=/",
|
||
"loggedout"
|
||
);
|
||
|
||
let mut response =
|
||
(StatusCode::CREATED, format!("Token deleted for device id ")).into_response();
|
||
|
||
response
|
||
.headers_mut()
|
||
.insert(SET_COOKIE, HeaderValue::from_str(&cookie_value).unwrap());
|
||
|
||
match result {
|
||
//Ok(count) if count > 0 => {
|
||
// (StatusCode::OK, format!("Revoked {} active tokens", count)).into_response()
|
||
//}
|
||
Ok(_) => response,
|
||
Err(err) => (
|
||
StatusCode::INTERNAL_SERVER_ERROR,
|
||
err.to_string(), // or format!("{err:?}")
|
||
)
|
||
.into_response(),
|
||
}
|
||
|
||
//response
|
||
}
|
||
|
||
#[derive(Serialize)]
|
||
struct HotelData {
|
||
id: i32,
|
||
hotel_name: String,
|
||
}
|
||
|
||
pub async fn get_hotel(State(state): State<AppState>) -> impl IntoResponse {
|
||
let try_conn = state.logs_pool.get();
|
||
|
||
let conn = match try_conn {
|
||
Ok(conn) => conn,
|
||
Err(e) => return (StatusCode::INTERNAL_SERVER_ERROR, "bruh").into_response(),
|
||
};
|
||
|
||
let try_stmt = conn.prepare(
|
||
"
|
||
SELECT id, hotelname
|
||
FROM hotels",
|
||
);
|
||
|
||
let mut stmt = match try_stmt {
|
||
Ok(stmt) => stmt,
|
||
Err(e) => {
|
||
return (
|
||
StatusCode::INTERNAL_SERVER_ERROR,
|
||
"failed buildin statement",
|
||
)
|
||
.into_response();
|
||
}
|
||
};
|
||
|
||
let try_hotels = stmt.query_map(params![], |row| {
|
||
Ok(HotelData {
|
||
id: row.get(0)?,
|
||
hotel_name: row.get(1)?,
|
||
})
|
||
});
|
||
|
||
let hotel_itter = match try_hotels {
|
||
Ok(hotels) => hotels,
|
||
Err(e) => {
|
||
return (
|
||
StatusCode::INTERNAL_SERVER_ERROR,
|
||
"error processing hotel list",
|
||
)
|
||
.into_response();
|
||
}
|
||
};
|
||
|
||
let hotels: Vec<HotelData> = match hotel_itter.collect::<Result<Vec<_>, _>>() {
|
||
Ok(hotel) => hotel,
|
||
Err(e) => {
|
||
return (
|
||
StatusCode::INTERNAL_SERVER_ERROR,
|
||
format!("failed collection of hotel : {e}"),
|
||
)
|
||
.into_response();
|
||
}
|
||
};
|
||
|
||
match serde_json::to_string(&hotels) {
|
||
Ok(json) => return (StatusCode::OK, json).into_response(),
|
||
Err(e) => {
|
||
return (
|
||
StatusCode::INTERNAL_SERVER_ERROR,
|
||
format!("Serialization failed: {}", e),
|
||
)
|
||
.into_response();
|
||
}
|
||
};
|
||
|
||
//.map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "DB connection error".to_string()))?;
|
||
//return (StatusCode::OK).into_response();
|
||
}
|
||
|
||
#[derive(Deserialize, Debug)]
|
||
pub struct addHotelUser {
|
||
user_id: i32,
|
||
#[serde(default)]
|
||
hotel_ids: Vec<i32>,
|
||
}
|
||
|
||
pub async fn add_hotel_user(
|
||
State(state): State<AppState>,
|
||
Extension(keys): Extension<JwtKeys>,
|
||
Json(payload): Json<addHotelUser>,
|
||
) -> impl IntoResponse {
|
||
let conn = match state.logs_pool.get() {
|
||
Ok(c) => c,
|
||
Err(e) => {
|
||
return (StatusCode::INTERNAL_SERVER_ERROR, "DB connection error").into_response();
|
||
}
|
||
};
|
||
|
||
let user_name: String = match conn.query_row(
|
||
"SELECT username FROM users WHERE id = ?1",
|
||
params![&payload.user_id],
|
||
|row| row.get(0),
|
||
) {
|
||
Ok(name) => name,
|
||
Err(e) => {
|
||
return (
|
||
StatusCode::INTERNAL_SERVER_ERROR,
|
||
format!("user not found {e} "),
|
||
)
|
||
.into_response();
|
||
}
|
||
};
|
||
|
||
let mut get_hotel_name_stmt = match conn.prepare("SELECT hotelname FROM hotels WHERE id = ?1") {
|
||
Ok(stmt) => stmt,
|
||
Err(e) => {
|
||
return (
|
||
StatusCode::INTERNAL_SERVER_ERROR,
|
||
format!("could't prepare stmt for hotel : {e} "),
|
||
)
|
||
.into_response();
|
||
}
|
||
};
|
||
|
||
let mut insert_hotel_link_stmt = match conn.prepare(
|
||
"INSERT INTO hotel_user_link
|
||
(user_id,hotel_id,username,hotelname)
|
||
VALUES (?1,?2,?3,?4)",
|
||
) {
|
||
Ok(stmt) => stmt,
|
||
Err(e) => {
|
||
return (
|
||
StatusCode::INTERNAL_SERVER_ERROR,
|
||
format!("could't prepare stmt to insert hotel : {e} "),
|
||
)
|
||
.into_response();
|
||
}
|
||
};
|
||
|
||
for &hotel_id in &payload.hotel_ids {
|
||
let hotel_name: String =
|
||
match get_hotel_name_stmt.query_row(params![hotel_id], |row| row.get(0)) {
|
||
Ok(name) => name,
|
||
Err(e) => {
|
||
return (
|
||
StatusCode::INTERNAL_SERVER_ERROR,
|
||
format!("hotel not found {e} "),
|
||
)
|
||
.into_response();
|
||
}
|
||
};
|
||
|
||
let add_link = match conn.execute(
|
||
"INSERT INTO hotel_user_link
|
||
(user_id,hotel_id,username,hotelname)
|
||
VALUES (?1,?2,?3,?4)",
|
||
params![payload.user_id, hotel_id, user_name, hotel_name],
|
||
) {
|
||
Ok(_) => true,
|
||
Err(e) => {
|
||
return (
|
||
StatusCode::INTERNAL_SERVER_ERROR,
|
||
format!("hotel not found {e} "),
|
||
)
|
||
.into_response();
|
||
}
|
||
};
|
||
|
||
//TODO: still need to build the add hotel to user here
|
||
}
|
||
|
||
return (StatusCode::OK, "goo").into_response();
|
||
}
|
||
|
||
fn internal_error<E: std::fmt::Display>(err: E) -> (StatusCode, String) {
|
||
(
|
||
StatusCode::INTERNAL_SERVER_ERROR,
|
||
format!("Internal error: {}", err),
|
||
)
|
||
}
|