From 1f0d9ec8452f15c27cd33c4e3874454c35993743 Mon Sep 17 00:00:00 2001 From: Santo Cariotti Date: Tue, 8 Apr 2025 14:37:33 +0200 Subject: Use internal/pkg structure --- internal/api/auth/auth.go | 57 +++++++++ internal/api/auth/auth_test.go | 74 ++++++++++++ internal/api/database/database.go | 32 +++++ internal/api/database/models.go | 24 ++++ internal/api/handlers/handlers.go | 197 +++++++++++++++++++++++++++++++ internal/api/handlers/utils.go | 34 ++++++ internal/api/middleware/middleware.go | 36 ++++++ internal/network/ip.go | 33 ++++++ internal/network/network.go | 213 ++++++++++++++++++++++++++++++++++ internal/network/network_test.go | 52 +++++++++ internal/network/session.go | 23 ++++ 11 files changed, 775 insertions(+) create mode 100644 internal/api/auth/auth.go create mode 100644 internal/api/auth/auth_test.go create mode 100644 internal/api/database/database.go create mode 100644 internal/api/database/models.go create mode 100644 internal/api/handlers/handlers.go create mode 100644 internal/api/handlers/utils.go create mode 100644 internal/api/middleware/middleware.go create mode 100644 internal/network/ip.go create mode 100644 internal/network/network.go create mode 100644 internal/network/network_test.go create mode 100644 internal/network/session.go (limited to 'internal') diff --git a/internal/api/auth/auth.go b/internal/api/auth/auth.go new file mode 100644 index 0000000..b382beb --- /dev/null +++ b/internal/api/auth/auth.go @@ -0,0 +1,57 @@ +package auth + +import ( + "errors" + "os" + "strings" + "time" + + "github.com/golang-jwt/jwt/v5" +) + +var jwtKey = []byte(os.Getenv("JWT_SECRET")) + +type Claims struct { + UserID int `json:"user_id"` + jwt.RegisteredClaims +} + +func GenerateJWT(userID int) (string, error) { + expirationTime := time.Now().Add(5 * time.Hour) + claims := &Claims{ + UserID: userID, + RegisteredClaims: jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(expirationTime), + }, + } + + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + tokenString, err := token.SignedString(jwtKey) + if err != nil { + return "", err + } + return tokenString, nil +} + +func ValidateJWT(tokenString string) (*Claims, error) { + claims := &Claims{} + // A token has a form `Bearer ...` + tokenParts := strings.Split(tokenString, " ") + if len(tokenParts) != 2 { + return nil, errors.New("not valid JWT") + } + + token, err := jwt.ParseWithClaims(tokenParts[1], claims, func(token *jwt.Token) (interface{}, error) { + return jwtKey, nil + }) + + if err != nil { + return nil, err + } + + if !token.Valid { + return nil, err + } + + return claims, nil +} diff --git a/internal/api/auth/auth_test.go b/internal/api/auth/auth_test.go new file mode 100644 index 0000000..50b6c9b --- /dev/null +++ b/internal/api/auth/auth_test.go @@ -0,0 +1,74 @@ +package auth + +import ( + "os" + "testing" + "time" + + "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/assert" +) + +func TestGenerateAndValidateJWT(t *testing.T) { + // Set up the JWT secret for the test. + os.Setenv("JWT_SECRET", "testsecret") + jwtKey = []byte(os.Getenv("JWT_SECRET")) + + userID := 123 + tokenString, err := GenerateJWT(userID) + assert.NoError(t, err) + assert.NotEmpty(t, tokenString) + + claims, err := ValidateJWT(tokenString) + assert.NoError(t, err) + assert.NotNil(t, claims) + assert.Equal(t, userID, claims.UserID) + assert.True(t, claims.ExpiresAt.After(time.Now())) +} + +func TestValidateJWT_InvalidToken(t *testing.T) { + os.Setenv("JWT_SECRET", "testsecret") + jwtKey = []byte(os.Getenv("JWT_SECRET")) + + _, err := ValidateJWT("invalid_token") + assert.Error(t, err) +} + +func TestValidateJWT_ExpiredToken(t *testing.T) { + os.Setenv("JWT_SECRET", "testsecret") + jwtKey = []byte(os.Getenv("JWT_SECRET")) + + // Create a token that has already expired. + claims := &Claims{ + UserID: 123, + RegisteredClaims: jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(-1 * time.Hour)), + }, + } + + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + tokenString, err := token.SignedString(jwtKey) + assert.NoError(t, err) + + _, err = ValidateJWT(tokenString) + assert.Error(t, err) +} + +func TestValidateJWT_WrongSecret(t *testing.T) { + os.Setenv("JWT_SECRET", "testsecret") + jwtKey = []byte(os.Getenv("JWT_SECRET")) + + userID := 123 + tokenString, err := GenerateJWT(userID) + assert.NoError(t, err) + + // Set a different secret for validation. + wrongKey := []byte("wrongsecret") + + claims := &Claims{} + _, err = jwt.ParseWithClaims(tokenString, claims, func(token *jwt.Token) (interface{}, error) { + return wrongKey, nil + }) + + assert.Error(t, err) +} diff --git a/internal/api/database/database.go b/internal/api/database/database.go new file mode 100644 index 0000000..4470c58 --- /dev/null +++ b/internal/api/database/database.go @@ -0,0 +1,32 @@ +package database + +import ( + "gorm.io/driver/postgres" + "gorm.io/gorm" + + "errors" +) + +// Global variable but private +var db *gorm.DB = nil + +// Init the database from a DSN string which must be a valid PostgreSQL dsn. +// Also, auto migrate all the models. +func InitDb(dsn string) (*gorm.DB, error) { + var err error + db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{}) + + if err == nil { + db.AutoMigrate(&User{}, &Game{}) + } + + return db, err +} + +// Return the instance or error if the config is not laoded yet +func GetDb() (*gorm.DB, error) { + if db == nil { + return nil, errors.New("You must call `InitDb()` first.") + } + return db, nil +} diff --git a/internal/api/database/models.go b/internal/api/database/models.go new file mode 100644 index 0000000..a6e76c5 --- /dev/null +++ b/internal/api/database/models.go @@ -0,0 +1,24 @@ +package database + +import "time" + +type User struct { + ID int `json:"id"` + Username string `json:"username"` + Password string `json:"password"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +type Game struct { + ID int `json:"id"` + Player1ID int `json:"-"` + Player1 User `gorm:"foreignKey:Player1ID" json:"player1"` + Player2ID *int `json:"-"` + Player2 *User `gorm:"foreignKey:Player2ID;null" json:"player2"` + Name string `json:"name"` + IP1 string `json:"ip1"` + IP2 string `json:"ip2"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} diff --git a/internal/api/handlers/handlers.go b/internal/api/handlers/handlers.go new file mode 100644 index 0000000..b448502 --- /dev/null +++ b/internal/api/handlers/handlers.go @@ -0,0 +1,197 @@ +package handlers + +import ( + "encoding/json" + "log/slog" + "net/http" + "time" + + "github.com/boozec/rahanna/internal/api/auth" + "github.com/boozec/rahanna/internal/api/database" + "github.com/boozec/rahanna/internal/network" + "gorm.io/gorm" +) + +type NewGameRequest struct { + IP string `json:"ip"` +} + +func RegisterUser(w http.ResponseWriter, r *http.Request) { + slog.Info("POST /auth/register") + var user database.User + err := json.NewDecoder(r.Body).Decode(&user) + if err != nil { + JsonError(&w, err.Error()) + return + } + + if len(user.Password) < 4 { + JsonError(&w, "password too short") + return + } + + var storedUser database.User + db, _ := database.GetDb() + result := db.Where("username = ?", user.Username).First(&storedUser) + + if result.Error == nil { + JsonError(&w, "user with this username already exists") + return + } + + hashedPassword, err := HashPassword(user.Password) + if err != nil { + JsonError(&w, err.Error()) + return + } + user.Password = string(hashedPassword) + + result = db.Create(&user) + if result.Error != nil { + JsonError(&w, result.Error.Error()) + return + } + + token, err := auth.GenerateJWT(user.ID) + if err != nil { + JsonError(&w, err.Error()) + return + } + + json.NewEncoder(w).Encode(map[string]string{"token": token}) +} + +func LoginUser(w http.ResponseWriter, r *http.Request) { + slog.Info("POST /auth/login") + var inputUser database.User + err := json.NewDecoder(r.Body).Decode(&inputUser) + if err != nil { + JsonError(&w, err.Error()) + return + } + + var storedUser database.User + + db, _ := database.GetDb() + result := db.Where("username = ?", inputUser.Username).First(&storedUser) + if result.Error != nil { + JsonError(&w, "invalid credentials") + return + } + + if err := CheckPasswordHash(storedUser.Password, inputUser.Password); err != nil { + JsonError(&w, "invalid credentials") + return + } + + token, err := auth.GenerateJWT(storedUser.ID) + if err != nil { + JsonError(&w, err.Error()) + return + } + + json.NewEncoder(w).Encode(map[string]string{"token": token}) +} + +func NewPlay(w http.ResponseWriter, r *http.Request) { + slog.Info("POST /play") + claims, err := auth.ValidateJWT(r.Header.Get("Authorization")) + + if err != nil { + JsonError(&w, err.Error()) + return + } + + var payload struct { + IP string `json:"ip"` + } + + if err := json.NewDecoder(r.Body).Decode(&payload); err != nil { + JsonError(&w, err.Error()) + return + } + + if err != nil { + JsonError(&w, err.Error()) + return + } + + db, _ := database.GetDb() + + name := network.NewSession() + play := database.Game{ + Player1ID: claims.UserID, + Player2ID: nil, + Name: name, + IP1: payload.IP, + IP2: "", + } + + result := db.Create(&play) + if result.Error != nil { + JsonError(&w, result.Error.Error()) + return + } + + json.NewEncoder(w).Encode(map[string]string{"name": name}) +} + +func EnterGame(w http.ResponseWriter, r *http.Request) { + slog.Info("POST /enter-game") + claims, err := auth.ValidateJWT(r.Header.Get("Authorization")) + + if err != nil { + JsonError(&w, err.Error()) + return + } + + var payload struct { + Name string `json:"name"` + IP string `json:"ip"` + } + + if err := json.NewDecoder(r.Body).Decode(&payload); err != nil { + JsonError(&w, err.Error()) + return + } + + if err != nil { + JsonError(&w, err.Error()) + return + } + + db, _ := database.GetDb() + + var play database.Game + + result := db.Where("name = ? AND player2_id IS NULL", payload.Name).First(&play) + if result.Error != nil { + JsonError(&w, result.Error.Error()) + return + } + + play.Player2ID = &claims.UserID + play.IP2 = payload.IP + play.UpdatedAt = time.Now() + + if err := db.Save(&play).Error; err != nil { + JsonError(&w, err.Error()) + return + } + + result = db.Where("id = ?", play.ID). + Preload("Player1", func(db *gorm.DB) *gorm.DB { + return db.Omit("Password") + }). + Preload("Player2", func(db *gorm.DB) *gorm.DB { + return db.Omit("Password") + }). + First(&play) + + if result.Error != nil { + JsonError(&w, result.Error.Error()) + return + } + + json.NewEncoder(w).Encode(play) +} diff --git a/internal/api/handlers/utils.go b/internal/api/handlers/utils.go new file mode 100644 index 0000000..d6cc0d6 --- /dev/null +++ b/internal/api/handlers/utils.go @@ -0,0 +1,34 @@ +package handlers + +import ( + "encoding/json" + "net/http" + + "golang.org/x/crypto/bcrypt" +) + +func HashPassword(password string) (string, error) { + bytes, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + return string(bytes), err +} + +func CheckPasswordHash(hash, password string) error { + return bcrypt.CompareHashAndPassword([]byte(hash), []byte(password)) +} + +// Set a JSON response with status code 400 +func JsonError(w *http.ResponseWriter, error string) { + payloadMap := map[string]string{"error": error} + + (*w).Header().Set("Content-Type", "application/json") + (*w).WriteHeader(http.StatusBadRequest) + + payload, err := json.Marshal(payloadMap) + + if err != nil { + (*w).WriteHeader(http.StatusBadGateway) + (*w).Write([]byte(err.Error())) + } else { + (*w).Write(payload) + } +} diff --git a/internal/api/middleware/middleware.go b/internal/api/middleware/middleware.go new file mode 100644 index 0000000..0334e78 --- /dev/null +++ b/internal/api/middleware/middleware.go @@ -0,0 +1,36 @@ +package middleware + +import ( + "encoding/json" + "net/http" + + "github.com/boozec/rahanna/internal/api/auth" +) + +func AuthMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + tokenString := r.Header.Get("Authorization") + + payloadMap := map[string]string{"error": "unauthorized"} + payload, _ := json.Marshal(payloadMap) + + if tokenString == "" { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusUnauthorized) + + w.Write([]byte(payload)) + return + } + + _, err := auth.ValidateJWT(tokenString) + if err != nil { + w.WriteHeader(http.StatusUnauthorized) + + payload, _ := json.Marshal(payloadMap) + + w.Write([]byte(payload)) + return + } + next.ServeHTTP(w, r) + }) +} diff --git a/internal/network/ip.go b/internal/network/ip.go new file mode 100644 index 0000000..0c6451e --- /dev/null +++ b/internal/network/ip.go @@ -0,0 +1,33 @@ +package network + +import ( + "fmt" + "log/slog" + "math/rand" + "net" +) + +func GetOutboundIP() net.IP { + conn, err := net.Dial("udp", "8.8.8.8:80") + if err != nil { + slog.Error("err", err) + } + defer conn.Close() + + localAddr := conn.LocalAddr().(*net.UDPAddr) + + return localAddr.IP +} + +func GetRandomAvailablePort() (int, error) { + for i := 0; i < 100; i += 1 { + port := rand.Intn(65535-1024) + 1024 + addr := fmt.Sprintf(":%d", port) + ln, err := net.Listen("tcp", addr) + if err == nil { + defer ln.Close() + return port, nil + } + } + return 0, fmt.Errorf("failed to find an available port after multiple attempts") +} diff --git a/internal/network/network.go b/internal/network/network.go new file mode 100644 index 0000000..8283993 --- /dev/null +++ b/internal/network/network.go @@ -0,0 +1,213 @@ +package network + +import ( + "bufio" + "encoding/json" + "fmt" + "net" + "sync" + "time" + + "go.uber.org/zap" +) + +var logger *zap.Logger + +// PeerInfo represents a peer's ID and IP. +type PeerInfo struct { + ID string `json:"id"` + IP string `json:"ip"` + Port int `json:"port"` +} + +// Message represents a structured message. +type Message struct { + Type string `json:"type"` + Payload []byte `json:"payload"` + Source PeerInfo `json:"source"` + Target PeerInfo `json:"target"` + Timestamp int64 `json:"timestamp"` +} + +type NetworkCallback func(msg Message) + +// TCPNetwork represents a full-duplex TCP peer. +type TCPNetwork struct { + localPeer PeerInfo + connections map[string]net.Conn + listener net.Listener + callbacks map[string]NetworkCallback + callbacksMu sync.RWMutex + isConnected bool + retryDelay time.Duration + sync.Mutex +} + +// initializes a TCP peer +func NewTCPNetwork(localID, localIP string, localPort int) *TCPNetwork { + n := &TCPNetwork{ + localPeer: PeerInfo{ID: localID, IP: localIP, Port: localPort}, + connections: make(map[string]net.Conn), + callbacks: make(map[string]NetworkCallback), + isConnected: false, + retryDelay: 2 * time.Second, + } + + go n.startServer() + + logger, _ = zap.NewProduction() + + return n +} + +// Add a new peer connection to the local peer +func (n *TCPNetwork) AddPeer(remoteID string, remoteIP string, remotePort int) { + go n.retryConnect(remoteID, remoteIP, remotePort) +} + +// startServer starts a TCP server to accept connections. +func (n *TCPNetwork) startServer() { + address := fmt.Sprintf("%s:%d", n.localPeer.IP, n.localPeer.Port) + listener, err := net.Listen("tcp", address) + if err != nil { + logger.Sugar().Errorf("failed to start server: %v", err) + } + n.listener = listener + logger.Sugar().Infof("server started on %s\n", address) + + for { + conn, err := listener.Accept() + if err != nil { + logger.Sugar().Errorf("failed to accept connection: %v\n", err) + continue + } + + remoteAddr := conn.RemoteAddr().String() + n.Lock() + n.connections[remoteAddr] = conn + n.Unlock() + n.isConnected = true + n.retryDelay = 2 * time.Second + + logger.Sugar().Infof("connected to remote peer %s\n", remoteAddr) + go n.listenForMessages(conn) + } +} + +// retryConnect attempts to connect to a remote peer. +func (n *TCPNetwork) retryConnect(remoteID, remoteIP string, remotePort int) { + for { + n.Lock() + _, exists := n.connections[remoteID] + n.Unlock() + + if exists { + time.Sleep(5 * time.Second) + continue + } + + address := fmt.Sprintf("%s:%d", remoteIP, remotePort) + conn, err := net.Dial("tcp", address) + + if err != nil { + logger.Sugar().Errorf("failed to connect to %s: %v. Retrying in %v...", remoteID, err, n.retryDelay) + time.Sleep(n.retryDelay) + if n.retryDelay < 30*time.Second { + n.retryDelay *= 2 + } + continue + } + + n.Lock() + n.connections[remoteID] = conn + n.Unlock() + logger.Sugar().Infof("successfully connected to peer %s!", remoteID) + + go n.listenForMessages(conn) + } +} + +// Send sends a message to a specified remote peer. +func (n *TCPNetwork) Send(remoteID, messageType string, payload []byte) error { + n.Lock() + conn, exists := n.connections[remoteID] + n.Unlock() + + if !exists { + return fmt.Errorf("not connected to peer %s", remoteID) + } + + msg := Message{ + Type: messageType, + Payload: payload, + Source: n.localPeer, + Target: PeerInfo{ID: remoteID}, + Timestamp: time.Now().Unix(), + } + + data, err := json.Marshal(msg) + if err != nil { + return fmt.Errorf("failed to marshal message: %v", err) + } + + _, err = conn.Write(append(data, '\n')) + if err != nil { + logger.Sugar().Errorf("failed to send message to %s: %v. Reconnecting...", remoteID, err) + n.Lock() + delete(n.connections, remoteID) + n.Unlock() + go n.retryConnect(remoteID, "", 0) + return fmt.Errorf("failed to send message: %v", err) + } + + return nil +} + +// RegisterHandler registers a callback for a message type. +func (n *TCPNetwork) RegisterHandler(messageType string, callback NetworkCallback) { + n.callbacksMu.Lock() + n.callbacks[messageType] = callback + n.callbacksMu.Unlock() +} + +// listenForMessages listens for incoming messages. +func (n *TCPNetwork) listenForMessages(conn net.Conn) { + reader := bufio.NewReader(conn) + + for { + data, err := reader.ReadBytes('\n') + if err != nil { + logger.Debug("connection lost. Reconnecting...") + n.Lock() + for id, c := range n.connections { + if c == conn { + delete(n.connections, id) + go n.retryConnect(id, "", 0) + break + } + } + n.Unlock() + return + } + + var message Message + if err := json.Unmarshal(data, &message); err != nil { + logger.Sugar().Errorf("failed to unmarshal message: %v\n", err) + continue + } + + n.callbacksMu.RLock() + callback, exists := n.callbacks[message.Type] + n.callbacksMu.RUnlock() + + if exists { + go callback(message) + } + } +} + +func (n *TCPNetwork) IsConnected() bool { + n.Lock() + defer n.Unlock() + return n.isConnected +} diff --git a/internal/network/network_test.go b/internal/network/network_test.go new file mode 100644 index 0000000..9dbc416 --- /dev/null +++ b/internal/network/network_test.go @@ -0,0 +1,52 @@ +package network + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +// TestPeerToPeerCommunication tests if two peers can communicate. +func TestPeerToPeerCommunication(t *testing.T) { + // Create a mock of the first peer (peer-1) + peer1IP := "127.0.0.1" + peer1Port := 9001 + peer1 := NewTCPNetwork("peer-1", peer1IP, peer1Port) + + // Create a mock of the second peer (peer-2) + peer2IP := "127.0.0.1" + peer2Port := 9002 + peer2 := NewTCPNetwork("peer-2", peer2IP, peer2Port) + + // Register a message handler on peer-2 to receive the message from peer-1 + peer2.RegisterHandler("chat", func(msg Message) { + assert.Equal(t, "peer-1", msg.Source.ID) + assert.Equal(t, "Hey from peer-1!", string(msg.Payload)) + }) + + // Start the first peer and add the second peer + go peer1.AddPeer("peer-2", peer2IP, peer2Port) + go peer2.AddPeer("peer-1", peer1IP, peer1Port) + + // Wait for the connections to be established + // You might need a little more time based on network delay and retry logic + time.Sleep(5 * time.Second) + + // Send a message from peer-1 to peer-2 + err := peer1.Send("peer-2", "chat", []byte("Hey from peer-1!")) + assert.NoError(t, err) + + // Allow some time for the message to be received and handled + time.Sleep(2 * time.Second) +} + +// TestSendFailure tests if sending a message fails when no connection exists. +func TestSendFailure(t *testing.T) { + peer1 := NewTCPNetwork("peer-1", "127.0.0.1", 9001) + _ = NewTCPNetwork("peer-2", "127.0.0.1", 9002) + + // Attempt to send a message without establishing a connection first + err := peer1.Send("peer-2", "chat", []byte("Message without connection")) + assert.Error(t, err, "Expected error when sending to a non-connected peer") +} diff --git a/internal/network/session.go b/internal/network/session.go new file mode 100644 index 0000000..a4f60aa --- /dev/null +++ b/internal/network/session.go @@ -0,0 +1,23 @@ +package network + +import ( + "math/rand" +) + +var adjectives = []string{ + "adamant", "adept", "adventurous", "arcadian", "auspicious", + "awesome", "blossoming", "brave", "charming", "chatty", + "circular", "considerate", "cubic", "curious", "delighted", +} + +var nouns = []string{ + "aardvark", "accordion", "apple", "apricot", "bee", + "brachiosaur", "cactus", "capsicum", "clarinet", "cowbell", + "crab", "cuckoo", "cymbal", "diplodocus", "donkey", +} + +func NewSession() string { + noun := nouns[rand.Intn(len(nouns))] + adjective := adjectives[rand.Intn(len(adjectives))] + return noun + "-" + adjective +} -- cgit v1.2.3-18-g5258