summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--Cargo.lock48
-rw-r--r--Cargo.toml1
-rw-r--r--src/errors.rs5
-rw-r--r--src/graphql/query.rs2
-rw-r--r--src/graphql/routes.rs9
-rw-r--r--src/graphql/types/jwt.rs55
-rw-r--r--src/graphql/types/user.rs41
-rw-r--r--src/main.rs4
-rw-r--r--src/state.rs1
9 files changed, 133 insertions, 33 deletions
diff --git a/Cargo.lock b/Cargo.lock
index 558b2a8..df102b6 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -275,6 +275,29 @@ dependencies = [
]
[[package]]
+name = "axum-extra"
+version = "0.9.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0be6ea09c9b96cb5076af0de2e383bd2bc0c18f827cf1967bdd353e0b910d733"
+dependencies = [
+ "axum",
+ "axum-core",
+ "bytes",
+ "futures-util",
+ "headers",
+ "http",
+ "http-body",
+ "http-body-util",
+ "mime",
+ "pin-project-lite",
+ "serde",
+ "tower",
+ "tower-layer",
+ "tower-service",
+ "tracing",
+]
+
+[[package]]
name = "backtrace"
version = "0.3.73"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -374,6 +397,7 @@ dependencies = [
"async-graphql",
"async-graphql-axum",
"axum",
+ "axum-extra",
"chrono",
"config",
"futures-util",
@@ -746,6 +770,30 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1"
[[package]]
+name = "headers"
+version = "0.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "322106e6bd0cba2d5ead589ddb8150a13d7c4217cf80d7c4f682ca994ccc6aa9"
+dependencies = [
+ "base64 0.21.7",
+ "bytes",
+ "headers-core",
+ "http",
+ "httpdate",
+ "mime",
+ "sha1",
+]
+
+[[package]]
+name = "headers-core"
+version = "0.3.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "54b4a22553d4242c49fddb9ba998a99962b5cc6f22cb5a3482bec22522403ce4"
+dependencies = [
+ "http",
+]
+
+[[package]]
name = "heck"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
diff --git a/Cargo.toml b/Cargo.toml
index 31bc110..8391d64 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -22,3 +22,4 @@ jsonwebtoken = "9.3.0"
once_cell = "1.19.0"
chrono = "0.4.38"
sha256 = "1.5.0"
+axum-extra = { version = "0.9.3", features = ["typed-header"] }
diff --git a/src/errors.rs b/src/errors.rs
index fafe0b0..1b9a802 100644
--- a/src/errors.rs
+++ b/src/errors.rs
@@ -42,10 +42,7 @@ impl IntoResponse for AppError {
"Token creation error".to_string(),
),
AppError::InvalidToken => (StatusCode::BAD_REQUEST, "Invalid token".to_string()),
- AppError::Unauthorized => (
- StatusCode::UNAUTHORIZED,
- "Can't perform this action".to_string(),
- ),
+ AppError::Unauthorized => (StatusCode::UNAUTHORIZED, "Unauthorized".to_string()),
};
let body = Json(json!({
diff --git a/src/graphql/query.rs b/src/graphql/query.rs
index ed83a9f..254dab6 100644
--- a/src/graphql/query.rs
+++ b/src/graphql/query.rs
@@ -1,4 +1,4 @@
-use crate::graphql::types::user;
+use crate::{errors::AppError, graphql::types::user};
use async_graphql::{Context, Object};
pub struct Query;
diff --git a/src/graphql/routes.rs b/src/graphql/routes.rs
index e15267e..a566c65 100644
--- a/src/graphql/routes.rs
+++ b/src/graphql/routes.rs
@@ -2,11 +2,14 @@ use crate::graphql::mutation::Mutation;
use crate::graphql::query::Query;
use async_graphql::{EmptySubscription, Schema};
use async_graphql_axum::{GraphQLRequest, GraphQLResponse};
-use std::sync::Arc;
+use axum::extract::Extension;
+
+use super::types::jwt::Authentication;
pub async fn graphql_handler(
- schema: Arc<Schema<Query, Mutation, EmptySubscription>>,
+ schema: Extension<Schema<Query, Mutation, EmptySubscription>>,
+ auth: Authentication,
req: GraphQLRequest,
) -> GraphQLResponse {
- schema.execute(req.into_inner()).await.into()
+ schema.execute(req.0.data(auth)).await.into()
}
diff --git a/src/graphql/types/jwt.rs b/src/graphql/types/jwt.rs
index 932f7fd..c118622 100644
--- a/src/graphql/types/jwt.rs
+++ b/src/graphql/types/jwt.rs
@@ -1,5 +1,10 @@
use crate::errors::AppError;
use async_graphql::{InputObject, SimpleObject};
+use axum::{async_trait, extract::FromRequestParts, http::request::Parts};
+use axum_extra::{
+ headers::{authorization::Bearer, Authorization},
+ typed_header::TypedHeader,
+};
use chrono::{Duration, Local};
use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation};
use once_cell::sync::Lazy;
@@ -7,7 +12,7 @@ use serde::{Deserialize, Serialize};
struct Keys {
encoding: EncodingKey,
- _decoding: DecodingKey,
+ decoding: DecodingKey,
}
static KEYS: Lazy<Keys> = Lazy::new(|| {
@@ -19,20 +24,25 @@ impl Keys {
fn new(secret: &[u8]) -> Self {
Self {
encoding: EncodingKey::from_secret(secret),
- _decoding: DecodingKey::from_secret(secret),
+ decoding: DecodingKey::from_secret(secret),
}
}
}
-/// Claims struct
-#[derive(Serialize, Deserialize)]
+/// Claims struct.
+#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Claims {
- /// ID from the user model
- pub user_id: i32,
- /// Expiration timestamp
+ user_id: i32,
exp: usize,
}
+/// Authentication enum
+#[derive(Serialize, Deserialize, Debug, Clone)]
+pub enum Authentication {
+ Logged(Claims),
+ NotLogged,
+}
+
impl Claims {
/// Create a new Claim using the `user_id` and the current timestamp + 2 days
pub fn new(user_id: i32) -> Self {
@@ -77,3 +87,34 @@ impl AuthBody {
}
}
}
+
+#[async_trait]
+impl<S> FromRequestParts<S> for Authentication
+where
+ S: Send + Sync,
+{
+ type Rejection = AppError;
+
+ async fn from_request_parts(parts: &mut Parts, _: &S) -> Result<Self, Self::Rejection> {
+ // Extract the Authorization header
+ match TypedHeader::<Authorization<Bearer>>::from_request_parts(parts, &()).await {
+ Ok(TypedHeader(Authorization(bearer))) => {
+ // Decode the token
+ let token_data =
+ decode::<Claims>(bearer.token(), &KEYS.decoding, &Validation::default())
+ .map_err(|err| match err.kind() {
+ jsonwebtoken::errors::ErrorKind::ExpiredSignature => {
+ AppError::InvalidToken
+ }
+ _ => {
+ eprintln!("{err:?}");
+ return AppError::Unauthorized;
+ }
+ })?;
+
+ Ok(Self::Logged(token_data.claims))
+ }
+ Err(_) => Ok(Self::NotLogged),
+ }
+ }
+}
diff --git a/src/graphql/types/user.rs b/src/graphql/types/user.rs
index bf9080f..b675f1f 100644
--- a/src/graphql/types/user.rs
+++ b/src/graphql/types/user.rs
@@ -2,6 +2,8 @@ use crate::state::AppState;
use async_graphql::{Context, Object};
use serde::{Deserialize, Serialize};
+use super::jwt::Authentication;
+
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct User {
pub id: i32,
@@ -32,21 +34,26 @@ impl User {
pub async fn get_users<'ctx>(ctx: &Context<'ctx>) -> Result<Option<Vec<User>>, String> {
let state = ctx.data::<AppState>().expect("Can't connect to db");
let client = &*state.client;
-
- let rows = client
- .query("SELECT id, email, password, is_admin FROM users", &[])
- .await
- .unwrap();
-
- let users: Vec<User> = rows
- .iter()
- .map(|row| User {
- id: row.get("id"),
- email: row.get("email"),
- password: row.get("password"),
- is_admin: row.get("is_admin"),
- })
- .collect();
-
- Ok(Some(users))
+ let auth: &Authentication = ctx.data().unwrap();
+ match auth {
+ Authentication::NotLogged => Err("Unauthorized".to_string()),
+ Authentication::Logged(_claims) => {
+ let rows = client
+ .query("SELECT id, email, password, is_admin FROM users", &[])
+ .await
+ .unwrap();
+
+ let users: Vec<User> = rows
+ .iter()
+ .map(|row| User {
+ id: row.get("id"),
+ email: row.get("email"),
+ password: row.get("password"),
+ is_admin: row.get("is_admin"),
+ })
+ .collect();
+
+ Ok(Some(users))
+ }
+ }
}
diff --git a/src/main.rs b/src/main.rs
index 04e2564..8dfc145 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -26,6 +26,7 @@ use tracing::Span;
async fn create_app() -> Router {
logger::setup();
let dbclient = db::setup().await.unwrap();
+
let state = state::AppState {
client: Arc::new(dbclient),
};
@@ -37,10 +38,11 @@ async fn create_app() -> Router {
)
.data(state.clone())
.finish();
+
Router::new()
.route(
"/graphql",
- post(move |req| graphql::routes::graphql_handler(schema.clone().into(), req)),
+ post(graphql::routes::graphql_handler).layer(Extension(schema.clone())),
)
.fallback(crate::routes::page_404)
// Mark the `Authorization` request header as sensitive so it doesn't
diff --git a/src/state.rs b/src/state.rs
index d4719b9..876a5f4 100644
--- a/src/state.rs
+++ b/src/state.rs
@@ -1,6 +1,7 @@
use std::sync::Arc;
use tokio_postgres::Client;
+
#[derive(Clone)]
pub struct AppState {
pub client: Arc<Client>,