diff options
author | Santo Cariotti <santo@dcariotti.me> | 2024-08-21 17:34:58 +0200 |
---|---|---|
committer | Santo Cariotti <santo@dcariotti.me> | 2024-08-21 17:34:58 +0200 |
commit | a76200bb7adb6189d84e0e98d6233a470ebeee98 (patch) | |
tree | f0c1dd5aba99ff33d01a5780f695cf65a9dcc411 | |
parent | a92fb07d23fb2268a6f4e650c5cbd00ad993e760 (diff) |
Authentication for endpoints
-rw-r--r-- | Cargo.lock | 48 | ||||
-rw-r--r-- | Cargo.toml | 1 | ||||
-rw-r--r-- | src/errors.rs | 5 | ||||
-rw-r--r-- | src/graphql/query.rs | 2 | ||||
-rw-r--r-- | src/graphql/routes.rs | 9 | ||||
-rw-r--r-- | src/graphql/types/jwt.rs | 55 | ||||
-rw-r--r-- | src/graphql/types/user.rs | 41 | ||||
-rw-r--r-- | src/main.rs | 4 | ||||
-rw-r--r-- | src/state.rs | 1 |
9 files changed, 133 insertions, 33 deletions
@@ -275,6 +275,29 @@ dependencies = [ ] [[package]] +name = "axum-extra" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0be6ea09c9b96cb5076af0de2e383bd2bc0c18f827cf1967bdd353e0b910d733" +dependencies = [ + "axum", + "axum-core", + "bytes", + "futures-util", + "headers", + "http", + "http-body", + "http-body-util", + "mime", + "pin-project-lite", + "serde", + "tower", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] name = "backtrace" version = "0.3.73" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -374,6 +397,7 @@ dependencies = [ "async-graphql", "async-graphql-axum", "axum", + "axum-extra", "chrono", "config", "futures-util", @@ -746,6 +770,30 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" [[package]] +name = "headers" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "322106e6bd0cba2d5ead589ddb8150a13d7c4217cf80d7c4f682ca994ccc6aa9" +dependencies = [ + "base64 0.21.7", + "bytes", + "headers-core", + "http", + "httpdate", + "mime", + "sha1", +] + +[[package]] +name = "headers-core" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "54b4a22553d4242c49fddb9ba998a99962b5cc6f22cb5a3482bec22522403ce4" +dependencies = [ + "http", +] + +[[package]] name = "heck" version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -22,3 +22,4 @@ jsonwebtoken = "9.3.0" once_cell = "1.19.0" chrono = "0.4.38" sha256 = "1.5.0" +axum-extra = { version = "0.9.3", features = ["typed-header"] } diff --git a/src/errors.rs b/src/errors.rs index fafe0b0..1b9a802 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -42,10 +42,7 @@ impl IntoResponse for AppError { "Token creation error".to_string(), ), AppError::InvalidToken => (StatusCode::BAD_REQUEST, "Invalid token".to_string()), - AppError::Unauthorized => ( - StatusCode::UNAUTHORIZED, - "Can't perform this action".to_string(), - ), + AppError::Unauthorized => (StatusCode::UNAUTHORIZED, "Unauthorized".to_string()), }; let body = Json(json!({ diff --git a/src/graphql/query.rs b/src/graphql/query.rs index ed83a9f..254dab6 100644 --- a/src/graphql/query.rs +++ b/src/graphql/query.rs @@ -1,4 +1,4 @@ -use crate::graphql::types::user; +use crate::{errors::AppError, graphql::types::user}; use async_graphql::{Context, Object}; pub struct Query; diff --git a/src/graphql/routes.rs b/src/graphql/routes.rs index e15267e..a566c65 100644 --- a/src/graphql/routes.rs +++ b/src/graphql/routes.rs @@ -2,11 +2,14 @@ use crate::graphql::mutation::Mutation; use crate::graphql::query::Query; use async_graphql::{EmptySubscription, Schema}; use async_graphql_axum::{GraphQLRequest, GraphQLResponse}; -use std::sync::Arc; +use axum::extract::Extension; + +use super::types::jwt::Authentication; pub async fn graphql_handler( - schema: Arc<Schema<Query, Mutation, EmptySubscription>>, + schema: Extension<Schema<Query, Mutation, EmptySubscription>>, + auth: Authentication, req: GraphQLRequest, ) -> GraphQLResponse { - schema.execute(req.into_inner()).await.into() + schema.execute(req.0.data(auth)).await.into() } diff --git a/src/graphql/types/jwt.rs b/src/graphql/types/jwt.rs index 932f7fd..c118622 100644 --- a/src/graphql/types/jwt.rs +++ b/src/graphql/types/jwt.rs @@ -1,5 +1,10 @@ use crate::errors::AppError; use async_graphql::{InputObject, SimpleObject}; +use axum::{async_trait, extract::FromRequestParts, http::request::Parts}; +use axum_extra::{ + headers::{authorization::Bearer, Authorization}, + typed_header::TypedHeader, +}; use chrono::{Duration, Local}; use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation}; use once_cell::sync::Lazy; @@ -7,7 +12,7 @@ use serde::{Deserialize, Serialize}; struct Keys { encoding: EncodingKey, - _decoding: DecodingKey, + decoding: DecodingKey, } static KEYS: Lazy<Keys> = Lazy::new(|| { @@ -19,20 +24,25 @@ impl Keys { fn new(secret: &[u8]) -> Self { Self { encoding: EncodingKey::from_secret(secret), - _decoding: DecodingKey::from_secret(secret), + decoding: DecodingKey::from_secret(secret), } } } -/// Claims struct -#[derive(Serialize, Deserialize)] +/// Claims struct. +#[derive(Serialize, Deserialize, Debug, Clone)] pub struct Claims { - /// ID from the user model - pub user_id: i32, - /// Expiration timestamp + user_id: i32, exp: usize, } +/// Authentication enum +#[derive(Serialize, Deserialize, Debug, Clone)] +pub enum Authentication { + Logged(Claims), + NotLogged, +} + impl Claims { /// Create a new Claim using the `user_id` and the current timestamp + 2 days pub fn new(user_id: i32) -> Self { @@ -77,3 +87,34 @@ impl AuthBody { } } } + +#[async_trait] +impl<S> FromRequestParts<S> for Authentication +where + S: Send + Sync, +{ + type Rejection = AppError; + + async fn from_request_parts(parts: &mut Parts, _: &S) -> Result<Self, Self::Rejection> { + // Extract the Authorization header + match TypedHeader::<Authorization<Bearer>>::from_request_parts(parts, &()).await { + Ok(TypedHeader(Authorization(bearer))) => { + // Decode the token + let token_data = + decode::<Claims>(bearer.token(), &KEYS.decoding, &Validation::default()) + .map_err(|err| match err.kind() { + jsonwebtoken::errors::ErrorKind::ExpiredSignature => { + AppError::InvalidToken + } + _ => { + eprintln!("{err:?}"); + return AppError::Unauthorized; + } + })?; + + Ok(Self::Logged(token_data.claims)) + } + Err(_) => Ok(Self::NotLogged), + } + } +} diff --git a/src/graphql/types/user.rs b/src/graphql/types/user.rs index bf9080f..b675f1f 100644 --- a/src/graphql/types/user.rs +++ b/src/graphql/types/user.rs @@ -2,6 +2,8 @@ use crate::state::AppState; use async_graphql::{Context, Object}; use serde::{Deserialize, Serialize}; +use super::jwt::Authentication; + #[derive(Clone, Debug, Serialize, Deserialize)] pub struct User { pub id: i32, @@ -32,21 +34,26 @@ impl User { pub async fn get_users<'ctx>(ctx: &Context<'ctx>) -> Result<Option<Vec<User>>, String> { let state = ctx.data::<AppState>().expect("Can't connect to db"); let client = &*state.client; - - let rows = client - .query("SELECT id, email, password, is_admin FROM users", &[]) - .await - .unwrap(); - - let users: Vec<User> = rows - .iter() - .map(|row| User { - id: row.get("id"), - email: row.get("email"), - password: row.get("password"), - is_admin: row.get("is_admin"), - }) - .collect(); - - Ok(Some(users)) + let auth: &Authentication = ctx.data().unwrap(); + match auth { + Authentication::NotLogged => Err("Unauthorized".to_string()), + Authentication::Logged(_claims) => { + let rows = client + .query("SELECT id, email, password, is_admin FROM users", &[]) + .await + .unwrap(); + + let users: Vec<User> = rows + .iter() + .map(|row| User { + id: row.get("id"), + email: row.get("email"), + password: row.get("password"), + is_admin: row.get("is_admin"), + }) + .collect(); + + Ok(Some(users)) + } + } } diff --git a/src/main.rs b/src/main.rs index 04e2564..8dfc145 100644 --- a/src/main.rs +++ b/src/main.rs @@ -26,6 +26,7 @@ use tracing::Span; async fn create_app() -> Router { logger::setup(); let dbclient = db::setup().await.unwrap(); + let state = state::AppState { client: Arc::new(dbclient), }; @@ -37,10 +38,11 @@ async fn create_app() -> Router { ) .data(state.clone()) .finish(); + Router::new() .route( "/graphql", - post(move |req| graphql::routes::graphql_handler(schema.clone().into(), req)), + post(graphql::routes::graphql_handler).layer(Extension(schema.clone())), ) .fallback(crate::routes::page_404) // Mark the `Authorization` request header as sensitive so it doesn't diff --git a/src/state.rs b/src/state.rs index d4719b9..876a5f4 100644 --- a/src/state.rs +++ b/src/state.rs @@ -1,6 +1,7 @@ use std::sync::Arc; use tokio_postgres::Client; + #[derive(Clone)] pub struct AppState { pub client: Arc<Client>, |