diff options
Diffstat (limited to 'api')
-rw-r--r-- | api/auth/auth.go | 48 | ||||
-rw-r--r-- | api/auth/auth_test.go | 74 | ||||
-rw-r--r-- | api/database/database.go | 32 | ||||
-rw-r--r-- | api/database/models.go | 11 | ||||
-rw-r--r-- | api/handlers/handlers.go | 68 | ||||
-rw-r--r-- | api/middleware/middleware.go | 36 |
6 files changed, 269 insertions, 0 deletions
diff --git a/api/auth/auth.go b/api/auth/auth.go new file mode 100644 index 0000000..23b4f53 --- /dev/null +++ b/api/auth/auth.go @@ -0,0 +1,48 @@ +package auth + +import ( + "github.com/golang-jwt/jwt/v5" + "os" + "time" +) + +var jwtKey = []byte(os.Getenv("JWT_SECRET")) + +type Claims struct { + UserID int `json:"user_id"` + jwt.RegisteredClaims +} + +func GenerateJWT(userID int) (string, error) { + expirationTime := time.Now().Add(5 * time.Hour) + claims := &Claims{ + UserID: userID, + RegisteredClaims: jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(expirationTime), + }, + } + + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + tokenString, err := token.SignedString(jwtKey) + if err != nil { + return "", err + } + return tokenString, nil +} + +func ValidateJWT(tokenString string) (*Claims, error) { + claims := &Claims{} + token, err := jwt.ParseWithClaims(tokenString, claims, func(token *jwt.Token) (interface{}, error) { + return jwtKey, nil + }) + + if err != nil { + return nil, err + } + + if !token.Valid { + return nil, err + } + + return claims, nil +} diff --git a/api/auth/auth_test.go b/api/auth/auth_test.go new file mode 100644 index 0000000..50b6c9b --- /dev/null +++ b/api/auth/auth_test.go @@ -0,0 +1,74 @@ +package auth + +import ( + "os" + "testing" + "time" + + "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/assert" +) + +func TestGenerateAndValidateJWT(t *testing.T) { + // Set up the JWT secret for the test. + os.Setenv("JWT_SECRET", "testsecret") + jwtKey = []byte(os.Getenv("JWT_SECRET")) + + userID := 123 + tokenString, err := GenerateJWT(userID) + assert.NoError(t, err) + assert.NotEmpty(t, tokenString) + + claims, err := ValidateJWT(tokenString) + assert.NoError(t, err) + assert.NotNil(t, claims) + assert.Equal(t, userID, claims.UserID) + assert.True(t, claims.ExpiresAt.After(time.Now())) +} + +func TestValidateJWT_InvalidToken(t *testing.T) { + os.Setenv("JWT_SECRET", "testsecret") + jwtKey = []byte(os.Getenv("JWT_SECRET")) + + _, err := ValidateJWT("invalid_token") + assert.Error(t, err) +} + +func TestValidateJWT_ExpiredToken(t *testing.T) { + os.Setenv("JWT_SECRET", "testsecret") + jwtKey = []byte(os.Getenv("JWT_SECRET")) + + // Create a token that has already expired. + claims := &Claims{ + UserID: 123, + RegisteredClaims: jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(-1 * time.Hour)), + }, + } + + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + tokenString, err := token.SignedString(jwtKey) + assert.NoError(t, err) + + _, err = ValidateJWT(tokenString) + assert.Error(t, err) +} + +func TestValidateJWT_WrongSecret(t *testing.T) { + os.Setenv("JWT_SECRET", "testsecret") + jwtKey = []byte(os.Getenv("JWT_SECRET")) + + userID := 123 + tokenString, err := GenerateJWT(userID) + assert.NoError(t, err) + + // Set a different secret for validation. + wrongKey := []byte("wrongsecret") + + claims := &Claims{} + _, err = jwt.ParseWithClaims(tokenString, claims, func(token *jwt.Token) (interface{}, error) { + return wrongKey, nil + }) + + assert.Error(t, err) +} diff --git a/api/database/database.go b/api/database/database.go new file mode 100644 index 0000000..e5ecca8 --- /dev/null +++ b/api/database/database.go @@ -0,0 +1,32 @@ +package database + +import ( + "gorm.io/driver/postgres" + "gorm.io/gorm" + + "errors" +) + +// Global variable but private +var db *gorm.DB = nil + +// Init the database from a DSN string which must be a valid PostgreSQL dsn. +// Also, auto migrate all the models. +func InitDb(dsn string) (*gorm.DB, error) { + var err error + db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{}) + + if err == nil { + db.AutoMigrate(&User{}) + } + + return db, err +} + +// Return the instance or error if the config is not laoded yet +func GetDb() (*gorm.DB, error) { + if db == nil { + return nil, errors.New("You must call `InitDb()` first.") + } + return db, nil +} diff --git a/api/database/models.go b/api/database/models.go new file mode 100644 index 0000000..e309a36 --- /dev/null +++ b/api/database/models.go @@ -0,0 +1,11 @@ +package database + +import "time" + +type User struct { + ID int `json:"id"` + Username string `json:"username"` + Password string `json:"password"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} diff --git a/api/handlers/handlers.go b/api/handlers/handlers.go new file mode 100644 index 0000000..7d5fd10 --- /dev/null +++ b/api/handlers/handlers.go @@ -0,0 +1,68 @@ +package handlers + +import ( + "encoding/json" + "net/http" + + "github.com/boozec/rahanna/api/auth" + "github.com/boozec/rahanna/api/database" + "golang.org/x/crypto/bcrypt" +) + +func RegisterUser(w http.ResponseWriter, r *http.Request) { + var user database.User + err := json.NewDecoder(r.Body).Decode(&user) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + hashedPassword, err := bcrypt.GenerateFromPassword([]byte(user.Password), bcrypt.DefaultCost) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + user.Password = string(hashedPassword) + + db, _ := database.GetDb() + + result := db.Create(&user) + if result.Error != nil { + http.Error(w, result.Error.Error(), http.StatusInternalServerError) + return + } + + w.WriteHeader(http.StatusCreated) +} + +func LoginUser(w http.ResponseWriter, r *http.Request) { + var inputUser database.User + err := json.NewDecoder(r.Body).Decode(&inputUser) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + var storedUser database.User + + db, _ := database.GetDb() + result := db.Where("username = ?", inputUser.Username).First(&storedUser) + if result.Error != nil { + http.Error(w, "Invalid credentials", http.StatusUnauthorized) + return + } + + err = bcrypt.CompareHashAndPassword([]byte(storedUser.Password), []byte(inputUser.Password)) + if err != nil { + http.Error(w, "Invalid credentials", http.StatusUnauthorized) + return + } + + token, err := auth.GenerateJWT(storedUser.ID) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + json.NewEncoder(w).Encode(map[string]string{"token": token}) +} diff --git a/api/middleware/middleware.go b/api/middleware/middleware.go new file mode 100644 index 0000000..29ed8b6 --- /dev/null +++ b/api/middleware/middleware.go @@ -0,0 +1,36 @@ +package middleware + +import ( + "encoding/json" + "net/http" + + "github.com/boozec/rahanna/api/auth" +) + +func AuthMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + tokenString := r.Header.Get("Authorization") + + payloadMap := map[string]string{"error": "unauthorized"} + payload, _ := json.Marshal(payloadMap) + + if tokenString == "" { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusUnauthorized) + + w.Write([]byte(payload)) + return + } + + _, err := auth.ValidateJWT(tokenString) + if err != nil { + w.WriteHeader(http.StatusUnauthorized) + + payload, _ := json.Marshal(payloadMap) + + w.Write([]byte(payload)) + return + } + next.ServeHTTP(w, r) + }) +} |