diff options
-rw-r--r-- | cmd/api/main.go | 3 | ||||
-rw-r--r-- | cmd/ui/main.go | 2 | ||||
-rw-r--r-- | internal/api/auth/auth.go | 24 | ||||
-rw-r--r-- | internal/api/auth/auth_test.go | 2 | ||||
-rw-r--r-- | internal/api/handlers/handlers.go | 104 | ||||
-rw-r--r-- | internal/api/middleware/middleware.go | 11 | ||||
-rw-r--r-- | internal/logger/logger.go | 32 | ||||
-rw-r--r-- | internal/network/ip.go | 2 | ||||
-rw-r--r-- | pkg/ui/multiplayer/multiplayer.go | 32 | ||||
-rw-r--r-- | pkg/ui/views/api.go | 4 | ||||
-rw-r--r-- | pkg/ui/views/game.go | 23 | ||||
-rw-r--r-- | pkg/ui/views/game_api.go | 11 | ||||
-rw-r--r-- | pkg/ui/views/game_keymap.go | 10 | ||||
-rw-r--r-- | pkg/ui/views/game_moves.go | 4 | ||||
-rw-r--r-- | pkg/ui/views/game_util.go | 2 | ||||
-rw-r--r-- | pkg/ui/views/play.go | 2 | ||||
-rw-r--r-- | pkg/ui/views/play_api.go | 6 |
17 files changed, 147 insertions, 127 deletions
diff --git a/cmd/api/main.go b/cmd/api/main.go index 4bd538a..6d6b354 100644 --- a/cmd/api/main.go +++ b/cmd/api/main.go @@ -14,7 +14,8 @@ import ( func main() { database.InitDb(os.Getenv("DATABASE_URL")) - log := logger.InitLogger("rahanna.log") + log := logger.InitLogger("rahanna.log", false) + r := mux.NewRouter() r.HandleFunc("/auth/register", handlers.RegisterUser).Methods(http.MethodPost) diff --git a/cmd/ui/main.go b/cmd/ui/main.go index 97f894a..930adf2 100644 --- a/cmd/ui/main.go +++ b/cmd/ui/main.go @@ -10,7 +10,7 @@ import ( func main() { views.ClearScreen() - _ = logger.InitLogger("rahanna-ui.log") + _ = logger.InitLogger("rahanna-ui.log", true) p := tea.NewProgram(views.NewRahannaModel(), tea.WithAltScreen()) diff --git a/internal/api/auth/auth.go b/internal/api/auth/auth.go index b382beb..966a09c 100644 --- a/internal/api/auth/auth.go +++ b/internal/api/auth/auth.go @@ -7,17 +7,26 @@ import ( "time" "github.com/golang-jwt/jwt/v5" + "gorm.io/gorm" ) +// Key used for JWT encryption/decryption var jwtKey = []byte(os.Getenv("JWT_SECRET")) +// Kind of JWT token +var TokenType = "Bearer" + +// Extends JWT Claims with the UserID field type Claims struct { UserID int `json:"user_id"` jwt.RegisteredClaims } +// Generate a JWT token from an userID. func GenerateJWT(userID int) (string, error) { - expirationTime := time.Now().Add(5 * time.Hour) + // Set expiration date for the token to 90 days + expirationTime := time.Now().Add(90 * 24 * time.Hour) + claims := &Claims{ UserID: userID, RegisteredClaims: jwt.RegisteredClaims{ @@ -30,17 +39,21 @@ func GenerateJWT(userID int) (string, error) { if err != nil { return "", err } - return tokenString, nil + return TokenType + " " + tokenString, nil } +// Validate a JWT token for a kind of time func ValidateJWT(tokenString string) (*Claims, error) { claims := &Claims{} - // A token has a form `Bearer ...` tokenParts := strings.Split(tokenString, " ") if len(tokenParts) != 2 { return nil, errors.New("not valid JWT") } + if tokenParts[0] != TokenType { + return nil, errors.New("not valid JWT type") + } + token, err := jwt.ParseWithClaims(tokenParts[1], claims, func(token *jwt.Token) (interface{}, error) { return jwtKey, nil }) @@ -55,3 +68,8 @@ func ValidateJWT(tokenString string) (*Claims, error) { return claims, nil } + +// Common omit password field for users +func OmitPassword(db *gorm.DB) *gorm.DB { + return db.Omit("Password") +} diff --git a/internal/api/auth/auth_test.go b/internal/api/auth/auth_test.go index 66bcc27..50b6c9b 100644 --- a/internal/api/auth/auth_test.go +++ b/internal/api/auth/auth_test.go @@ -19,8 +19,6 @@ func TestGenerateAndValidateJWT(t *testing.T) { assert.NoError(t, err) assert.NotEmpty(t, tokenString) - tokenString = "Bearer " + tokenString - claims, err := ValidateJWT(tokenString) assert.NoError(t, err) assert.NotNil(t, claims) diff --git a/internal/api/handlers/handlers.go b/internal/api/handlers/handlers.go index 41779c7..6d1b4e3 100644 --- a/internal/api/handlers/handlers.go +++ b/internal/api/handlers/handlers.go @@ -21,6 +21,7 @@ type NewGameRequest struct { func RegisterUser(w http.ResponseWriter, r *http.Request) { log, _ := logger.GetLogger() log.Info("POST /auth/register") + var user database.User err := json.NewDecoder(r.Body).Decode(&user) if err != nil { @@ -35,9 +36,7 @@ func RegisterUser(w http.ResponseWriter, r *http.Request) { var storedUser database.User db, _ := database.GetDb() - result := db.Where("username = ?", user.Username).First(&storedUser) - - if result.Error == nil { + if result := db.Where("username = ?", user.Username).First(&storedUser); result.Error == nil { JsonError(&w, "user with this username already exists") return } @@ -49,8 +48,7 @@ func RegisterUser(w http.ResponseWriter, r *http.Request) { } user.Password = string(hashedPassword) - result = db.Create(&user) - if result.Error != nil { + if result := db.Create(&user); result.Error != nil { JsonError(&w, result.Error.Error()) return } @@ -67,6 +65,7 @@ func RegisterUser(w http.ResponseWriter, r *http.Request) { func LoginUser(w http.ResponseWriter, r *http.Request) { log, _ := logger.GetLogger() log.Info("POST /auth/login") + var inputUser database.User err := json.NewDecoder(r.Body).Decode(&inputUser) if err != nil { @@ -77,8 +76,7 @@ func LoginUser(w http.ResponseWriter, r *http.Request) { var storedUser database.User db, _ := database.GetDb() - result := db.Where("username = ?", inputUser.Username).First(&storedUser) - if result.Error != nil { + if result := db.Where("username = ?", inputUser.Username).First(&storedUser); result.Error != nil { JsonError(&w, "invalid credentials") return } @@ -100,10 +98,10 @@ func LoginUser(w http.ResponseWriter, r *http.Request) { func NewPlay(w http.ResponseWriter, r *http.Request) { log, _ := logger.GetLogger() log.Info("POST /play") - claims, err := auth.ValidateJWT(r.Header.Get("Authorization")) - if err != nil { - JsonError(&w, err.Error()) + claims, ok := r.Context().Value("claims").(*auth.Claims) + if !ok { + JsonError(&w, "claims not found") return } @@ -116,11 +114,6 @@ func NewPlay(w http.ResponseWriter, r *http.Request) { return } - if err != nil { - JsonError(&w, err.Error()) - return - } - db, _ := database.GetDb() name := network.NewSession() @@ -133,8 +126,7 @@ func NewPlay(w http.ResponseWriter, r *http.Request) { Outcome: "*", } - result := db.Create(&play) - if result.Error != nil { + if result := db.Create(&play); result.Error != nil { JsonError(&w, result.Error.Error()) return } @@ -145,10 +137,10 @@ func NewPlay(w http.ResponseWriter, r *http.Request) { func EnterGame(w http.ResponseWriter, r *http.Request) { log, _ := logger.GetLogger() log.Info("POST /enter-game") - claims, err := auth.ValidateJWT(r.Header.Get("Authorization")) - if err != nil { - JsonError(&w, err.Error()) + claims, ok := r.Context().Value("claims").(*auth.Claims) + if !ok { + JsonError(&w, "claims not found") return } @@ -162,17 +154,11 @@ func EnterGame(w http.ResponseWriter, r *http.Request) { return } - if err != nil { - JsonError(&w, err.Error()) - return - } - db, _ := database.GetDb() var game database.Game - result := db.Where("name = ? AND player2_id IS NULL", payload.Name).First(&game) - if result.Error != nil { + if result := db.Where("name = ? AND player2_id IS NULL", payload.Name).First(&game); result.Error != nil { JsonError(&w, result.Error.Error()) return } @@ -186,13 +172,9 @@ func EnterGame(w http.ResponseWriter, r *http.Request) { return } - result = db.Where("id = ?", game.ID). - Preload("Player1", func(db *gorm.DB) *gorm.DB { - return db.Omit("Password") - }). - Preload("Player2", func(db *gorm.DB) *gorm.DB { - return db.Omit("Password") - }). + result := db.Where("id = ?", game.ID). + Preload("Player1", auth.OmitPassword). + Preload("Player2", auth.OmitPassword). First(&game) if result.Error != nil { @@ -207,17 +189,16 @@ func AllPlay(w http.ResponseWriter, r *http.Request) { log, _ := logger.GetLogger() log.Info("GET /play") - claims, err := auth.ValidateJWT(r.Header.Get("Authorization")) - - if err != nil { - JsonError(&w, err.Error()) + claims, ok := r.Context().Value("claims").(*auth.Claims) + if !ok { + JsonError(&w, "claims not found") return } db, _ := database.GetDb() var games []database.Game - result := db.Where("player1_id = ? OR player2_id = ?", claims.UserID, claims.UserID). + if result := db.Where("player1_id = ? OR player2_id = ?", claims.UserID, claims.UserID). Preload("Player1", func(db *gorm.DB) *gorm.DB { return db.Omit("Password") }). @@ -225,9 +206,7 @@ func AllPlay(w http.ResponseWriter, r *http.Request) { return db.Omit("Password") }). Order("updated_at DESC"). - Find(&games) - - if result.Error != nil { + Find(&games); result.Error != nil { JsonError(&w, result.Error.Error()) return } @@ -241,26 +220,23 @@ func GetGameId(w http.ResponseWriter, r *http.Request) { id := vars["id"] log.Info(fmt.Sprintf("GET /play/%s", id)) - claims, err := auth.ValidateJWT(r.Header.Get("Authorization")) - - if err != nil { - JsonError(&w, err.Error()) + claims, ok := r.Context().Value("claims").(*auth.Claims) + if !ok { + JsonError(&w, "claims not found") return } db, _ := database.GetDb() var game database.Game - result := db.Where("id = ? AND (player1_id = ? OR player2_id = ?)", id, claims.UserID, claims.UserID). + if result := db.Where("id = ? AND (player1_id = ? OR player2_id = ?)", id, claims.UserID, claims.UserID). Preload("Player1", func(db *gorm.DB) *gorm.DB { return db.Omit("Password") }). Preload("Player2", func(db *gorm.DB) *gorm.DB { return db.Omit("Password") }). - First(&game) - - if result.Error != nil { + First(&game); result.Error != nil { JsonError(&w, result.Error.Error()) return } @@ -274,10 +250,9 @@ func EndGame(w http.ResponseWriter, r *http.Request) { id := vars["id"] log.Info(fmt.Sprintf("POST /play/%s/end", id)) - claims, err := auth.ValidateJWT(r.Header.Get("Authorization")) - - if err != nil { - JsonError(&w, err.Error()) + claims, ok := r.Context().Value("claims").(*auth.Claims) + if !ok { + JsonError(&w, "claims not found") return } @@ -290,18 +265,15 @@ func EndGame(w http.ResponseWriter, r *http.Request) { return } - if err != nil { - JsonError(&w, err.Error()) - return - } - db, _ := database.GetDb() var game database.Game // FIXME: this is not secure - result := db.Where("id = ? AND (player1_id = ? OR player2_id = ?)", id, claims.UserID, claims.UserID).First(&game) - if result.Error != nil { + if result := db.Where( + "id = ? AND (player1_id = ? OR player2_id = ?)", + id, claims.UserID, claims.UserID, + ).First(&game); result.Error != nil { JsonError(&w, result.Error.Error()) return } @@ -313,13 +285,9 @@ func EndGame(w http.ResponseWriter, r *http.Request) { return } - result = db.Where("id = ?", game.ID). - Preload("Player1", func(db *gorm.DB) *gorm.DB { - return db.Omit("Password") - }). - Preload("Player2", func(db *gorm.DB) *gorm.DB { - return db.Omit("Password") - }). + result := db.Where("id = ?", game.ID). + Preload("Player1", auth.OmitPassword). + Preload("Player2", auth.OmitPassword). First(&game) if result.Error != nil { diff --git a/internal/api/middleware/middleware.go b/internal/api/middleware/middleware.go index 0334e78..d7c5a30 100644 --- a/internal/api/middleware/middleware.go +++ b/internal/api/middleware/middleware.go @@ -1,12 +1,16 @@ package middleware import ( + "context" "encoding/json" "net/http" "github.com/boozec/rahanna/internal/api/auth" ) +// AuthMiddleware ensures that the requester has passed the Authorization +// header with a valid JWY token. +// It passes the claims item via context func AuthMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { tokenString := r.Header.Get("Authorization") @@ -22,7 +26,7 @@ func AuthMiddleware(next http.Handler) http.Handler { return } - _, err := auth.ValidateJWT(tokenString) + claims, err := auth.ValidateJWT(tokenString) if err != nil { w.WriteHeader(http.StatusUnauthorized) @@ -31,6 +35,9 @@ func AuthMiddleware(next http.Handler) http.Handler { w.Write([]byte(payload)) return } - next.ServeHTTP(w, r) + + ctx := context.WithValue(r.Context(), "claims", claims) + + next.ServeHTTP(w, r.WithContext(ctx)) }) } diff --git a/internal/logger/logger.go b/internal/logger/logger.go index a5d0264..b7ecb86 100644 --- a/internal/logger/logger.go +++ b/internal/logger/logger.go @@ -2,6 +2,7 @@ package logger import ( "errors" + "os" "go.uber.org/zap" "go.uber.org/zap/zapcore" @@ -10,31 +11,44 @@ import ( var logger *zap.Logger = nil -func InitLogger(logFile string) *zap.Logger { +// Set up a new Zap logger. If `onlyFile` is true, set up the logger to work +// only on file, else prints on stdout +func InitLogger(logFile string, onlyFile bool) *zap.Logger { cfg := zap.NewProductionConfig() - cfg.OutputPaths = []string{logFile} - cfg.ErrorOutputPaths = []string{logFile} + cfg.EncoderConfig.EncodeTime = zapcore.ISO8601TimeEncoder + + var core zapcore.Core - // Configure lumberjack for log rotation lumberjackLogger := &lumberjack.Logger{ Filename: logFile, - MaxSize: 100, // megabytes + MaxSize: 100, MaxBackups: 5, - MaxAge: 30, // days + MaxAge: 30, Compress: true, } - core := zapcore.NewCore( + fileCore := zapcore.NewCore( zapcore.NewJSONEncoder(cfg.EncoderConfig), - zapcore.AddSync(lumberjackLogger), // Log only to the file via lumberjack + zapcore.AddSync(lumberjackLogger), cfg.Level, ) - logger = zap.New(core) + if onlyFile { + core = fileCore + } else { + consoleCore := zapcore.NewCore( + zapcore.NewConsoleEncoder(cfg.EncoderConfig), + zapcore.Lock(os.Stdout), + cfg.Level, + ) + core = zapcore.NewTee(fileCore, consoleCore) + } + logger = zap.New(core) return logger } +// Return the global Zap logger after calling `InitLogger` method func GetLogger() (*zap.Logger, error) { if logger == nil { return nil, errors.New("You must call `InitLogger()` first.") diff --git a/internal/network/ip.go b/internal/network/ip.go index dcd15db..ec1e984 100644 --- a/internal/network/ip.go +++ b/internal/network/ip.go @@ -8,6 +8,7 @@ import ( "github.com/boozec/rahanna/internal/logger" ) +// Connect a DNS to get the address func GetOutboundIP() net.IP { log, _ := logger.GetLogger() conn, err := net.Dial("udp", "8.8.8.8:80") @@ -21,6 +22,7 @@ func GetOutboundIP() net.IP { return localAddr.IP } +// Returns a random available port on the node func GetRandomAvailablePort() (int, error) { for i := 0; i < 100; i += 1 { port := rand.Intn(65535-1024) + 1024 diff --git a/pkg/ui/multiplayer/multiplayer.go b/pkg/ui/multiplayer/multiplayer.go index 1680035..c9fc4b2 100644 --- a/pkg/ui/multiplayer/multiplayer.go +++ b/pkg/ui/multiplayer/multiplayer.go @@ -8,10 +8,12 @@ import ( ) type GameNetwork struct { - Server *network.TCPNetwork - Peer string + server *network.TCPNetwork + me network.NetworkID + peer network.NetworkID } +// Wrapper to a `TCPNetwork` func NewGameNetwork(localID string, address string, onHandshake network.NetworkHandshakeFunc, logger *zap.Logger) *GameNetwork { opts := network.TCPNetworkOpts{ ListenAddr: address, @@ -20,9 +22,29 @@ func NewGameNetwork(localID string, address string, onHandshake network.NetworkH Logger: logger, } server := network.NewTCPNetwork(network.NetworkID(localID), opts) - peer := "" return &GameNetwork{ - Server: server, - Peer: peer, + server: server, + me: network.NetworkID(localID), } } + +func (n *GameNetwork) Peer() network.NetworkID { + return n.peer +} + +func (n *GameNetwork) Me() network.NetworkID { + return n.me +} + +func (n *GameNetwork) Send(payload []byte) error { + return n.server.Send(n.peer, payload) +} + +func (n *GameNetwork) AddPeer(remoteID network.NetworkID, addr string) { + n.peer = remoteID + n.server.AddPeer(remoteID, addr) +} + +func (n *GameNetwork) AddReceiveFunction(f network.NetworkMessageReceiveFunc) { + n.server.OnReceiveFn = f +} diff --git a/pkg/ui/views/api.go b/pkg/ui/views/api.go index ba202c8..14f4cf2 100644 --- a/pkg/ui/views/api.go +++ b/pkg/ui/views/api.go @@ -38,7 +38,7 @@ func getUserID() (int, error) { return -1, err } - claims, err := auth.ValidateJWT("Bearer " + token) + claims, err := auth.ValidateJWT(token) if err != nil { return -1, err } @@ -55,7 +55,7 @@ func sendAPIRequest(method, url string, payload []byte, authorization string) (* } req.Header.Set("Content-Type", "application/json") - req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", authorization)) + req.Header.Add("Authorization", authorization) client := &http.Client{} return client.Do(req) diff --git a/pkg/ui/views/game.go b/pkg/ui/views/game.go index 7075f8b..a97d068 100644 --- a/pkg/ui/views/game.go +++ b/pkg/ui/views/game.go @@ -5,7 +5,6 @@ import ( "strings" "github.com/boozec/rahanna/internal/api/database" - "github.com/boozec/rahanna/internal/network" "github.com/boozec/rahanna/pkg/ui/multiplayer" "github.com/charmbracelet/bubbles/list" "github.com/charmbracelet/bubbles/textinput" @@ -25,7 +24,6 @@ type GameModel struct { keys gameKeyMap // Game state - peer string currentGameID int game *database.Game network *multiplayer.GameNetwork @@ -36,7 +34,7 @@ type GameModel struct { } // NewGameModel creates a new GameModel. -func NewGameModel(width, height int, peer string, currentGameID int, network *multiplayer.GameNetwork) GameModel { +func NewGameModel(width, height int, currentGameID int, network *multiplayer.GameNetwork) GameModel { listDelegate := list.NewDefaultDelegate() listDelegate.ShowDescription = false listDelegate.Styles.SelectedTitle = lipgloss.NewStyle(). @@ -55,7 +53,6 @@ func NewGameModel(width, height int, peer string, currentGameID int, network *mu width: width, height: height, keys: defaultGameKeyMap, - peer: peer, currentGameID: currentGameID, network: network, chessGame: chess.NewGame(chess.UseNotation(chess.UCINotation{})), @@ -97,10 +94,10 @@ func (m GameModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { cmds = append(cmds, cmd, m.updateMovesListCmd()) case EndGameMsg: if msg.abandoned { - if m.peer == "peer-2" { - m.game.Outcome = "1-0" + if m.network.Me() == "peer-1" { + m.game.Outcome = string(chess.WhiteWon) } else { - m.game.Outcome = "0-1" + m.game.Outcome = string(chess.BlackWon) } m, cmd = m.handleDatabaseGameMsg(*m.game) cmds = append(cmds, cmd) @@ -123,7 +120,7 @@ func (m GameModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { m.err = err } else { m.turn++ - m.network.Server.Send(network.NetworkID(m.peer), []byte(moveStr)) + m.network.Send([]byte(moveStr)) m.err = nil } cmds = append(cmds, m.getMoves(), m.updateMovesListCmd()) @@ -172,17 +169,17 @@ func (m GameModel) View() string { } else { var outcome string switch m.game.Outcome { - case "1-0": + case string(chess.WhiteWon): outcome = "White won" - if m.peer == "peer-2" { + if m.network.Me() == "peer-1" { outcome += " (YOU)" } - case "0-1": + case string(chess.BlackWon): outcome = "Black won" - if m.peer == "peer-1" { + if m.network.Me() == "peer-2" { outcome += " (YOU)" } - case "1/2-1/2": + case string(chess.Draw): outcome = "Draw" default: outcome = "NoOutcome" diff --git a/pkg/ui/views/game_api.go b/pkg/ui/views/game_api.go index 34ba1f3..485df41 100644 --- a/pkg/ui/views/game_api.go +++ b/pkg/ui/views/game_api.go @@ -12,11 +12,6 @@ import ( func (m GameModel) handleDatabaseGameMsg(msg database.Game) (GameModel, tea.Cmd) { m.game = &msg - if m.peer == "peer-2" { - m.network.Peer = msg.IP2 - } else { - m.network.Peer = msg.IP1 - } var cmd tea.Cmd @@ -52,15 +47,15 @@ func (m *GameModel) getGame() tea.Cmd { } // Establish peer connection - if m.peer == "peer-2" { + if m.network.Me() == "peer-1" { if game.IP2 != "" { remote := game.IP2 - go m.network.Server.AddPeer("peer-2", remote) + go m.network.AddPeer("peer-2", remote) } } else { if game.IP1 != "" { remote := game.IP1 - go m.network.Server.AddPeer("peer-1", remote) + go m.network.AddPeer("peer-1", remote) } } diff --git a/pkg/ui/views/game_keymap.go b/pkg/ui/views/game_keymap.go index 29881c8..5c65a57 100644 --- a/pkg/ui/views/game_keymap.go +++ b/pkg/ui/views/game_keymap.go @@ -3,10 +3,10 @@ package views import ( "fmt" - "github.com/boozec/rahanna/internal/network" "github.com/charmbracelet/bubbles/key" tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" + "github.com/notnil/chess" ) // gameKeyMap defines the key bindings for the game view. @@ -36,13 +36,13 @@ func (m GameModel) handleKeyMsg(msg tea.KeyMsg) (GameModel, tea.Cmd) { switch { case key.Matches(msg, m.keys.Abandon): var outcome string - if m.peer == "peer-2" { - outcome = "0-1" + if m.network.Me() == "peer-1" { + outcome = string(chess.BlackWon) } else { - outcome = "1-0" + outcome = string(chess.WhiteWon) } - m.network.Server.Send(network.NetworkID(m.peer), []byte("🏳️")) + m.network.Send([]byte("🏳️")) return m, m.endGame(outcome) case key.Matches(msg, m.keys.Quit): return m, SwitchModelCmd(NewPlayModel(m.width, m.height)) diff --git a/pkg/ui/views/game_moves.go b/pkg/ui/views/game_moves.go index 4ce3796..eeee9e1 100644 --- a/pkg/ui/views/game_moves.go +++ b/pkg/ui/views/game_moves.go @@ -24,10 +24,10 @@ func (i item) Description() string { return "" } func (i item) FilterValue() string { return i.title } func (m *GameModel) getMoves() tea.Cmd { - m.network.Server.OnReceiveFn = func(msg network.Message) { + m.network.AddReceiveFunction(func(msg network.Message) { payload := string(msg.Payload) m.incomingMoves <- payload - } + }) return func() tea.Msg { move := <-m.incomingMoves diff --git a/pkg/ui/views/game_util.go b/pkg/ui/views/game_util.go index 7d83eee..eb904c5 100644 --- a/pkg/ui/views/game_util.go +++ b/pkg/ui/views/game_util.go @@ -24,5 +24,5 @@ func (m GameModel) buildWindowContent(content string, formWidth int) string { } func (m GameModel) isMyTurn() bool { - return m.turn%2 == 0 && m.peer == "peer-2" || m.turn%2 == 1 && m.peer == "peer-1" + return m.turn%2 == 0 && m.network.Me() == "peer-1" || m.turn%2 == 1 && m.network.Me() == "peer-2" } diff --git a/pkg/ui/views/play.go b/pkg/ui/views/play.go index ed5f1e1..20e2ebb 100644 --- a/pkg/ui/views/play.go +++ b/pkg/ui/views/play.go @@ -89,7 +89,7 @@ func (m PlayModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { m.userID, m.err = getUserID() return m.handleGamesResponse(msg) case StartGameMsg: - return m, SwitchModelCmd(NewGameModel(m.width, m.height+1, "peer-2", m.currentGameId, m.network)) + return m, SwitchModelCmd(NewGameModel(m.width, m.height+1, m.currentGameId, m.network)) case error: return m.handleError(msg) } diff --git a/pkg/ui/views/play_api.go b/pkg/ui/views/play_api.go index 7119ffa..40f26a8 100644 --- a/pkg/ui/views/play_api.go +++ b/pkg/ui/views/play_api.go @@ -68,11 +68,9 @@ func (m *PlayModel) handleGameResponse(msg database.Game) (tea.Model, tea.Cmd) { localPort, _ := strconv.ParseInt(ip[1], 10, 32) logger, _ := logger.GetLogger() - network := multiplayer.NewGameNetwork("peer-2", fmt.Sprintf("%s:%d", localIP, localPort), func() error { - return nil - }, logger) + network := multiplayer.NewGameNetwork("peer-2", fmt.Sprintf("%s:%d", localIP, localPort), network.DefaultHandshake, logger) - return m, SwitchModelCmd(NewGameModel(m.width, m.height+1, "peer-1", m.game.ID, network)) + return m, SwitchModelCmd(NewGameModel(m.width, m.height+1, m.game.ID, network)) } return m, nil } |