diff options
Diffstat (limited to 'internal')
-rw-r--r-- | internal/api/auth/auth.go | 24 | ||||
-rw-r--r-- | internal/api/auth/auth_test.go | 2 | ||||
-rw-r--r-- | internal/api/handlers/handlers.go | 104 | ||||
-rw-r--r-- | internal/api/middleware/middleware.go | 11 | ||||
-rw-r--r-- | internal/logger/logger.go | 32 | ||||
-rw-r--r-- | internal/network/ip.go | 2 |
6 files changed, 91 insertions, 84 deletions
diff --git a/internal/api/auth/auth.go b/internal/api/auth/auth.go index b382beb..966a09c 100644 --- a/internal/api/auth/auth.go +++ b/internal/api/auth/auth.go @@ -7,17 +7,26 @@ import ( "time" "github.com/golang-jwt/jwt/v5" + "gorm.io/gorm" ) +// Key used for JWT encryption/decryption var jwtKey = []byte(os.Getenv("JWT_SECRET")) +// Kind of JWT token +var TokenType = "Bearer" + +// Extends JWT Claims with the UserID field type Claims struct { UserID int `json:"user_id"` jwt.RegisteredClaims } +// Generate a JWT token from an userID. func GenerateJWT(userID int) (string, error) { - expirationTime := time.Now().Add(5 * time.Hour) + // Set expiration date for the token to 90 days + expirationTime := time.Now().Add(90 * 24 * time.Hour) + claims := &Claims{ UserID: userID, RegisteredClaims: jwt.RegisteredClaims{ @@ -30,17 +39,21 @@ func GenerateJWT(userID int) (string, error) { if err != nil { return "", err } - return tokenString, nil + return TokenType + " " + tokenString, nil } +// Validate a JWT token for a kind of time func ValidateJWT(tokenString string) (*Claims, error) { claims := &Claims{} - // A token has a form `Bearer ...` tokenParts := strings.Split(tokenString, " ") if len(tokenParts) != 2 { return nil, errors.New("not valid JWT") } + if tokenParts[0] != TokenType { + return nil, errors.New("not valid JWT type") + } + token, err := jwt.ParseWithClaims(tokenParts[1], claims, func(token *jwt.Token) (interface{}, error) { return jwtKey, nil }) @@ -55,3 +68,8 @@ func ValidateJWT(tokenString string) (*Claims, error) { return claims, nil } + +// Common omit password field for users +func OmitPassword(db *gorm.DB) *gorm.DB { + return db.Omit("Password") +} diff --git a/internal/api/auth/auth_test.go b/internal/api/auth/auth_test.go index 66bcc27..50b6c9b 100644 --- a/internal/api/auth/auth_test.go +++ b/internal/api/auth/auth_test.go @@ -19,8 +19,6 @@ func TestGenerateAndValidateJWT(t *testing.T) { assert.NoError(t, err) assert.NotEmpty(t, tokenString) - tokenString = "Bearer " + tokenString - claims, err := ValidateJWT(tokenString) assert.NoError(t, err) assert.NotNil(t, claims) diff --git a/internal/api/handlers/handlers.go b/internal/api/handlers/handlers.go index 41779c7..6d1b4e3 100644 --- a/internal/api/handlers/handlers.go +++ b/internal/api/handlers/handlers.go @@ -21,6 +21,7 @@ type NewGameRequest struct { func RegisterUser(w http.ResponseWriter, r *http.Request) { log, _ := logger.GetLogger() log.Info("POST /auth/register") + var user database.User err := json.NewDecoder(r.Body).Decode(&user) if err != nil { @@ -35,9 +36,7 @@ func RegisterUser(w http.ResponseWriter, r *http.Request) { var storedUser database.User db, _ := database.GetDb() - result := db.Where("username = ?", user.Username).First(&storedUser) - - if result.Error == nil { + if result := db.Where("username = ?", user.Username).First(&storedUser); result.Error == nil { JsonError(&w, "user with this username already exists") return } @@ -49,8 +48,7 @@ func RegisterUser(w http.ResponseWriter, r *http.Request) { } user.Password = string(hashedPassword) - result = db.Create(&user) - if result.Error != nil { + if result := db.Create(&user); result.Error != nil { JsonError(&w, result.Error.Error()) return } @@ -67,6 +65,7 @@ func RegisterUser(w http.ResponseWriter, r *http.Request) { func LoginUser(w http.ResponseWriter, r *http.Request) { log, _ := logger.GetLogger() log.Info("POST /auth/login") + var inputUser database.User err := json.NewDecoder(r.Body).Decode(&inputUser) if err != nil { @@ -77,8 +76,7 @@ func LoginUser(w http.ResponseWriter, r *http.Request) { var storedUser database.User db, _ := database.GetDb() - result := db.Where("username = ?", inputUser.Username).First(&storedUser) - if result.Error != nil { + if result := db.Where("username = ?", inputUser.Username).First(&storedUser); result.Error != nil { JsonError(&w, "invalid credentials") return } @@ -100,10 +98,10 @@ func LoginUser(w http.ResponseWriter, r *http.Request) { func NewPlay(w http.ResponseWriter, r *http.Request) { log, _ := logger.GetLogger() log.Info("POST /play") - claims, err := auth.ValidateJWT(r.Header.Get("Authorization")) - if err != nil { - JsonError(&w, err.Error()) + claims, ok := r.Context().Value("claims").(*auth.Claims) + if !ok { + JsonError(&w, "claims not found") return } @@ -116,11 +114,6 @@ func NewPlay(w http.ResponseWriter, r *http.Request) { return } - if err != nil { - JsonError(&w, err.Error()) - return - } - db, _ := database.GetDb() name := network.NewSession() @@ -133,8 +126,7 @@ func NewPlay(w http.ResponseWriter, r *http.Request) { Outcome: "*", } - result := db.Create(&play) - if result.Error != nil { + if result := db.Create(&play); result.Error != nil { JsonError(&w, result.Error.Error()) return } @@ -145,10 +137,10 @@ func NewPlay(w http.ResponseWriter, r *http.Request) { func EnterGame(w http.ResponseWriter, r *http.Request) { log, _ := logger.GetLogger() log.Info("POST /enter-game") - claims, err := auth.ValidateJWT(r.Header.Get("Authorization")) - if err != nil { - JsonError(&w, err.Error()) + claims, ok := r.Context().Value("claims").(*auth.Claims) + if !ok { + JsonError(&w, "claims not found") return } @@ -162,17 +154,11 @@ func EnterGame(w http.ResponseWriter, r *http.Request) { return } - if err != nil { - JsonError(&w, err.Error()) - return - } - db, _ := database.GetDb() var game database.Game - result := db.Where("name = ? AND player2_id IS NULL", payload.Name).First(&game) - if result.Error != nil { + if result := db.Where("name = ? AND player2_id IS NULL", payload.Name).First(&game); result.Error != nil { JsonError(&w, result.Error.Error()) return } @@ -186,13 +172,9 @@ func EnterGame(w http.ResponseWriter, r *http.Request) { return } - result = db.Where("id = ?", game.ID). - Preload("Player1", func(db *gorm.DB) *gorm.DB { - return db.Omit("Password") - }). - Preload("Player2", func(db *gorm.DB) *gorm.DB { - return db.Omit("Password") - }). + result := db.Where("id = ?", game.ID). + Preload("Player1", auth.OmitPassword). + Preload("Player2", auth.OmitPassword). First(&game) if result.Error != nil { @@ -207,17 +189,16 @@ func AllPlay(w http.ResponseWriter, r *http.Request) { log, _ := logger.GetLogger() log.Info("GET /play") - claims, err := auth.ValidateJWT(r.Header.Get("Authorization")) - - if err != nil { - JsonError(&w, err.Error()) + claims, ok := r.Context().Value("claims").(*auth.Claims) + if !ok { + JsonError(&w, "claims not found") return } db, _ := database.GetDb() var games []database.Game - result := db.Where("player1_id = ? OR player2_id = ?", claims.UserID, claims.UserID). + if result := db.Where("player1_id = ? OR player2_id = ?", claims.UserID, claims.UserID). Preload("Player1", func(db *gorm.DB) *gorm.DB { return db.Omit("Password") }). @@ -225,9 +206,7 @@ func AllPlay(w http.ResponseWriter, r *http.Request) { return db.Omit("Password") }). Order("updated_at DESC"). - Find(&games) - - if result.Error != nil { + Find(&games); result.Error != nil { JsonError(&w, result.Error.Error()) return } @@ -241,26 +220,23 @@ func GetGameId(w http.ResponseWriter, r *http.Request) { id := vars["id"] log.Info(fmt.Sprintf("GET /play/%s", id)) - claims, err := auth.ValidateJWT(r.Header.Get("Authorization")) - - if err != nil { - JsonError(&w, err.Error()) + claims, ok := r.Context().Value("claims").(*auth.Claims) + if !ok { + JsonError(&w, "claims not found") return } db, _ := database.GetDb() var game database.Game - result := db.Where("id = ? AND (player1_id = ? OR player2_id = ?)", id, claims.UserID, claims.UserID). + if result := db.Where("id = ? AND (player1_id = ? OR player2_id = ?)", id, claims.UserID, claims.UserID). Preload("Player1", func(db *gorm.DB) *gorm.DB { return db.Omit("Password") }). Preload("Player2", func(db *gorm.DB) *gorm.DB { return db.Omit("Password") }). - First(&game) - - if result.Error != nil { + First(&game); result.Error != nil { JsonError(&w, result.Error.Error()) return } @@ -274,10 +250,9 @@ func EndGame(w http.ResponseWriter, r *http.Request) { id := vars["id"] log.Info(fmt.Sprintf("POST /play/%s/end", id)) - claims, err := auth.ValidateJWT(r.Header.Get("Authorization")) - - if err != nil { - JsonError(&w, err.Error()) + claims, ok := r.Context().Value("claims").(*auth.Claims) + if !ok { + JsonError(&w, "claims not found") return } @@ -290,18 +265,15 @@ func EndGame(w http.ResponseWriter, r *http.Request) { return } - if err != nil { - JsonError(&w, err.Error()) - return - } - db, _ := database.GetDb() var game database.Game // FIXME: this is not secure - result := db.Where("id = ? AND (player1_id = ? OR player2_id = ?)", id, claims.UserID, claims.UserID).First(&game) - if result.Error != nil { + if result := db.Where( + "id = ? AND (player1_id = ? OR player2_id = ?)", + id, claims.UserID, claims.UserID, + ).First(&game); result.Error != nil { JsonError(&w, result.Error.Error()) return } @@ -313,13 +285,9 @@ func EndGame(w http.ResponseWriter, r *http.Request) { return } - result = db.Where("id = ?", game.ID). - Preload("Player1", func(db *gorm.DB) *gorm.DB { - return db.Omit("Password") - }). - Preload("Player2", func(db *gorm.DB) *gorm.DB { - return db.Omit("Password") - }). + result := db.Where("id = ?", game.ID). + Preload("Player1", auth.OmitPassword). + Preload("Player2", auth.OmitPassword). First(&game) if result.Error != nil { diff --git a/internal/api/middleware/middleware.go b/internal/api/middleware/middleware.go index 0334e78..d7c5a30 100644 --- a/internal/api/middleware/middleware.go +++ b/internal/api/middleware/middleware.go @@ -1,12 +1,16 @@ package middleware import ( + "context" "encoding/json" "net/http" "github.com/boozec/rahanna/internal/api/auth" ) +// AuthMiddleware ensures that the requester has passed the Authorization +// header with a valid JWY token. +// It passes the claims item via context func AuthMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { tokenString := r.Header.Get("Authorization") @@ -22,7 +26,7 @@ func AuthMiddleware(next http.Handler) http.Handler { return } - _, err := auth.ValidateJWT(tokenString) + claims, err := auth.ValidateJWT(tokenString) if err != nil { w.WriteHeader(http.StatusUnauthorized) @@ -31,6 +35,9 @@ func AuthMiddleware(next http.Handler) http.Handler { w.Write([]byte(payload)) return } - next.ServeHTTP(w, r) + + ctx := context.WithValue(r.Context(), "claims", claims) + + next.ServeHTTP(w, r.WithContext(ctx)) }) } diff --git a/internal/logger/logger.go b/internal/logger/logger.go index a5d0264..b7ecb86 100644 --- a/internal/logger/logger.go +++ b/internal/logger/logger.go @@ -2,6 +2,7 @@ package logger import ( "errors" + "os" "go.uber.org/zap" "go.uber.org/zap/zapcore" @@ -10,31 +11,44 @@ import ( var logger *zap.Logger = nil -func InitLogger(logFile string) *zap.Logger { +// Set up a new Zap logger. If `onlyFile` is true, set up the logger to work +// only on file, else prints on stdout +func InitLogger(logFile string, onlyFile bool) *zap.Logger { cfg := zap.NewProductionConfig() - cfg.OutputPaths = []string{logFile} - cfg.ErrorOutputPaths = []string{logFile} + cfg.EncoderConfig.EncodeTime = zapcore.ISO8601TimeEncoder + + var core zapcore.Core - // Configure lumberjack for log rotation lumberjackLogger := &lumberjack.Logger{ Filename: logFile, - MaxSize: 100, // megabytes + MaxSize: 100, MaxBackups: 5, - MaxAge: 30, // days + MaxAge: 30, Compress: true, } - core := zapcore.NewCore( + fileCore := zapcore.NewCore( zapcore.NewJSONEncoder(cfg.EncoderConfig), - zapcore.AddSync(lumberjackLogger), // Log only to the file via lumberjack + zapcore.AddSync(lumberjackLogger), cfg.Level, ) - logger = zap.New(core) + if onlyFile { + core = fileCore + } else { + consoleCore := zapcore.NewCore( + zapcore.NewConsoleEncoder(cfg.EncoderConfig), + zapcore.Lock(os.Stdout), + cfg.Level, + ) + core = zapcore.NewTee(fileCore, consoleCore) + } + logger = zap.New(core) return logger } +// Return the global Zap logger after calling `InitLogger` method func GetLogger() (*zap.Logger, error) { if logger == nil { return nil, errors.New("You must call `InitLogger()` first.") diff --git a/internal/network/ip.go b/internal/network/ip.go index dcd15db..ec1e984 100644 --- a/internal/network/ip.go +++ b/internal/network/ip.go @@ -8,6 +8,7 @@ import ( "github.com/boozec/rahanna/internal/logger" ) +// Connect a DNS to get the address func GetOutboundIP() net.IP { log, _ := logger.GetLogger() conn, err := net.Dial("udp", "8.8.8.8:80") @@ -21,6 +22,7 @@ func GetOutboundIP() net.IP { return localAddr.IP } +// Returns a random available port on the node func GetRandomAvailablePort() (int, error) { for i := 0; i < 100; i += 1 { port := rand.Intn(65535-1024) + 1024 |