diff options
Diffstat (limited to 'pkg/p2p/network.go')
-rw-r--r-- | pkg/p2p/network.go | 47 |
1 files changed, 25 insertions, 22 deletions
diff --git a/pkg/p2p/network.go b/pkg/p2p/network.go index 13c99b0..a609832 100644 --- a/pkg/p2p/network.go +++ b/pkg/p2p/network.go @@ -40,11 +40,12 @@ func DefaultHandshake(conn net.Conn) error { // Network options to define on new `TCPNetwork` type TCPNetworkOpts struct { - ListenAddr string - RetryDelay time.Duration - HandshakeFn NetworkHandshakeFunc - OnReceiveFn NetworkMessageReceiveFunc - Logger *zap.Logger + ListenAddr string + RetryDelay time.Duration + HandshakeFn NetworkHandshakeFunc + FirstHandshakeFn NetworkHandshakeFunc + OnReceiveFn NetworkMessageReceiveFunc + Logger *zap.Logger } // PeerConnection holds the connection and address of a peer. @@ -58,10 +59,11 @@ type TCPNetwork struct { sync.Mutex TCPNetworkOpts - id NetworkID - listener net.Listener - connections map[NetworkID]PeerConnection - isClosed bool + id NetworkID + listener net.Listener + connections map[NetworkID]PeerConnection + isClosed bool + handshakesCount uint } // Initiliaze a new TCP network @@ -100,11 +102,10 @@ func (n *TCPNetwork) Close() error { // Add a new peer connection to the local peer func (n *TCPNetwork) AddPeer(remoteID NetworkID, addr string) { n.Lock() - if _, exists := n.connections[remoteID]; !exists { - n.connections[remoteID] = PeerConnection{Address: addr} - go n.retryConnect(remoteID, addr) - } + n.connections[remoteID] = PeerConnection{Address: addr} n.Unlock() + + go n.retryConnect(remoteID, addr) } // Send methods is used to send a message to a specified remote peer @@ -140,7 +141,6 @@ func (n *TCPNetwork) Send(remoteID NetworkID, messageType []byte, payload []byte if err != nil { n.Logger.Sugar().Errorf("failed to send message to %s: %v. Reconnecting...", remoteID, err) n.removeConnection(remoteID) - go n.retryConnect(remoteID, peerConn.Address) return fmt.Errorf("failed to send message: %v", err) } else { n.Logger.Sugar().Infof("sent message to '%s' (%s): %s", remoteID, peerConn.Address, string(message.Payload)) @@ -187,6 +187,7 @@ func (n *TCPNetwork) handleConnection(conn net.Conn) { remoteID := NetworkID(remoteAddr) n.Lock() + n.handshakesCount++ n.connections[remoteID] = PeerConnection{Conn: conn, Address: remoteAddr} n.Unlock() @@ -199,6 +200,16 @@ func (n *TCPNetwork) handleConnection(conn net.Conn) { } } + if n.FirstHandshakeFn != nil && n.handshakesCount == 1 { + if err := n.FirstHandshakeFn(conn); err != nil { + n.Logger.Sugar().Errorf("error on first handshake with %s: %v\n", remoteAddr, err) + conn.Close() + n.removeConnection(remoteID) + return + } + + } + n.Logger.Sugar().Infof("connected to remote peer %s (%s)\n", remoteID, remoteAddr) n.listenForMessages(conn, remoteID) @@ -228,14 +239,6 @@ func (n *TCPNetwork) listenForMessages(conn net.Conn, remoteID NetworkID) { n.Logger.Sugar().Warnf("error reading from connection %s: %v", remoteAddr, err) } - n.Lock() - peerConn, exists := n.connections[remoteID] - n.Unlock() - if exists { - go n.retryConnect(remoteID, peerConn.Address) - } else { - n.Logger.Sugar().Warnf("no address to reconnect to peer %s", remoteID) - } return } |