Commit 293d84d7 authored by Anders Jensen Løvig's avatar Anders Jensen Løvig
Browse files

Refactor network message handlers

parent 1782e909
Pipeline #21738 passed with stages
in 1 minute and 1 second
......@@ -8,7 +8,6 @@ import (
"math/big"
"os"
"strings"
"sync"
"time"
"github.com/jessevdk/go-flags"
......@@ -19,8 +18,6 @@ var opts struct {
}
var server struct {
sync.Mutex
// Server enables this authority to listen for incomming connections.
Server *network.Server
// The election being held
......@@ -42,18 +39,16 @@ func init() {
os.Exit(0)
}
server.Connections = make(map[string]*network.Conn)
server.Mutex = sync.Mutex{}
server.Closed = make(chan bool, 1)
server.Done = make(chan bool, 1)
// Init configuration
// Load configuration to get election parameters and server list.
server.Config, err = network.LoadConfig()
if err != nil {
log.Fatal("fatal:", err)
}
// Find our current port and server id
portStr := fmt.Sprintf(":%d", opts.Port)
var serialNumber *big.Int = nil
for address, sn := range server.Config.Servers {
......@@ -67,30 +62,11 @@ func init() {
log.Fatal("fatal: server is not on server list")
}
// Initialize server using correct certificates
server.Config.CertPath = fmt.Sprintf("config/cert/server/s%d.crt", serialNumber)
server.Config.KeyPath = fmt.Sprintf("config/cert/server/s%d.key", serialNumber)
// Initialize election
servers := len(server.Config.Servers)
required := (servers / 2) + 1
config := &election.Config{
ServerID: election.UniqueID(serialNumber.String()),
Voters: 10,
Servers: servers,
RequiredServers: required,
Deadline: time.Now().Add(30 * time.Second),
CloseCallback: closeCallback,
TallyCallback: tallyCallback,
ResultCallback: resultCallback,
}
server.Election = election.NewElection(config)
// Initialize server
server.Server = network.NewServer(server.Config)
}
var clientHandlers = map[string]network.MessageHandler{
election.MsgBallot: func(conn *network.Conn, data network.MessageData) *network.Message {
server.Server.OnClientMessage(election.MsgBallot, func(conn *network.Conn, data network.MessageData) *network.Message {
var ballot election.Ballot
err := data.Unmarshal(&ballot)
if err != nil {
......@@ -99,14 +75,12 @@ var clientHandlers = map[string]network.MessageHandler{
go server.Election.HandleBallot(conn.SerialNumber(), &ballot)
}
return network.NewMessage(election.MsgBallotAck, nil)
},
}
var serverHandlers = map[string]network.MessageHandler{
election.MsgClosing: func(conn *network.Conn, data network.MessageData) *network.Message {
})
server.Server.OnServerMessage(election.MsgClosing, func(conn *network.Conn, data network.MessageData) *network.Message {
server.Election.HandleClosing(conn.SerialNumber())
return nil
},
election.MsgTally: func(conn *network.Conn, data network.MessageData) *network.Message {
})
server.Server.OnServerMessage(election.MsgTally, func(conn *network.Conn, data network.MessageData) *network.Message {
var tally election.Tally
err := data.Unmarshal(&tally)
if err != nil {
......@@ -115,71 +89,80 @@ var serverHandlers = map[string]network.MessageHandler{
go server.Election.HandleTally(conn.SerialNumber(), &tally)
}
return nil
},
}
func main() {
err := server.Server.Start(opts.Port, func(connType network.ConnType) map[string]network.MessageHandler {
if connType == network.ConnTypeClient {
return clientHandlers
} else {
return serverHandlers
}
})
if err != nil {
log.Fatal("Main: fatal:", err)
// Initialize election
servers := len(server.Config.Servers)
required := (servers / 2) + 1
config := &election.Config{
ServerID: election.UniqueID(serialNumber.String()),
Voters: 10,
Servers: servers,
RequiredServers: required,
Deadline: time.Now().Add(30 * time.Second),
CloseCallback: closeCallback,
TallyCallback: tallyCallback,
ResultCallback: resultCallback,
}
err = server.Election.Start()
server.Election = election.NewElection(config)
}
func startServer() {
err := server.Server.Start(opts.Port)
if err != nil {
log.Fatal("Main: fatal:", err)
}
time.Sleep(1 * time.Second) // Grace period
<-server.Done // Closed
<-server.Done // Result
server.Server.Close()
}
func closeCallback(reason election.Reason) {
func conenctServers() {
connections := 0
msgClosed := network.NewMessage(election.MsgClosing, nil)
// Connect to other servers
for address := range server.Config.Servers {
if address != server.Address {
conn, err := server.Server.Connect(address, serverHandlers)
_, err := server.Server.Connect(address)
if err != nil {
log.Println("Main: error:", err)
continue
}
server.Connections[address] = conn
connections++
err = conn.WriteMessage(msgClosed)
if err != nil {
log.Println("Main: error:", err)
}
}
}
if connections == 0 {
log.Fatal("Main: fatal: could not connect to servers")
}
}
func startElection() {
err := server.Election.Start()
if err != nil {
log.Fatal("Main: fatal:", err)
}
}
func main() {
startServer()
time.Sleep(2 * time.Second) // Give servers time to start before connecting
conenctServers()
time.Sleep(1 * time.Second) // To let servers connect before starting election
startElection()
<-server.Done // Closed
<-server.Done // Result
time.Sleep(1 * time.Second) // Grace period
server.Server.Close()
}
func closeCallback(reason election.Reason) {
msgClosed := network.NewMessage(election.MsgClosing, nil)
server.Server.Broadcast(msgClosed)
server.Done <- true
}
func tallyCallback(tally *election.Tally) {
msgTally := network.NewMessage(election.MsgTally, tally)
// Connect to other servers
for _, conn := range server.Connections {
err := conn.WriteMessage(msgTally)
if err != nil {
log.Println("Main: error:", err)
}
}
server.Server.Broadcast(msgTally)
}
func resultCallback(result *election.Result) {
......
......@@ -30,8 +30,11 @@ var voter struct {
Config *network.Config
Client *network.Client
Acks map[string]bool
Done chan bool
ConnectionCount int
RecievedCount int
Acks map[string]bool
Done chan bool
}
func init() {
......@@ -49,7 +52,21 @@ func init() {
config.KeyPath = opts.Key
// Init client
voter.RecievedCount = 0
voter.Client = network.NewClient(config)
voter.Client.OnMessage(election.MsgBallotAck, func(conn *network.Conn, data network.MessageData) *network.Message {
voter.Lock()
defer voter.Unlock()
voter.Acks[conn.RemoteAddr().String()] = true
conn.Close()
voter.RecievedCount++
if voter.RecievedCount == voter.ConnectionCount {
voter.Done <- true
}
return nil
})
voter.Mutex = sync.Mutex{}
voter.Acks = make(map[string]bool)
......@@ -124,27 +141,10 @@ func inputVote() int64 {
func sendBallots(ballots map[election.UniqueID]*election.Ballot) {
voter.Lock()
defer voter.Unlock()
connectionCount := 0
receivedCount := 0
i := 0
for address := range voter.Config.Servers {
conn, err := voter.Client.Connect(address,
map[string]network.MessageHandler{
election.MsgBallotAck: func(conn *network.Conn, data network.MessageData) *network.Message {
voter.Lock()
defer voter.Unlock()
voter.Acks[address] = true
conn.Close()
receivedCount++
if receivedCount == connectionCount {
voter.Done <- true
}
return nil
},
})
conn, err := voter.Client.Connect(address)
if err != nil {
log.Println("Main: error:", err)
continue
......@@ -157,12 +157,12 @@ func sendBallots(ballots map[election.UniqueID]*election.Ballot) {
log.Println("Main: error:", err)
continue
}
voter.Acks[address] = false
connectionCount++
voter.Acks[conn.RemoteAddr().String()] = false
voter.ConnectionCount++
i++
}
if connectionCount == 0 {
if voter.ConnectionCount == 0 {
log.Fatal("fatal: could not connect to servers")
}
}
......@@ -25,25 +25,27 @@ import (
// be closed using client.CloseConn(conn), where conn is a connection returned
// by client.Connect(address).
type Client struct {
sync.Mutex
// Map of current connection to the status of the connection.
// All connections returned by client.Connect(address) is put into the map.
// The boolean value indicates if the connection is closing. When all
// go-rountines associated with a connection has terminated the connection
// will be removed from connections.
connections map[*Conn]bool
mutex sync.Mutex
config *Config
tlsConfig *tls.Config
verifyOptions x509.VerifyOptions
handlers map[string]MessageHandler
}
func NewClient(config *Config) *Client {
tlsConfig := createClientConfig(config.RootPath, config.CertPath, config.KeyPath)
logName = "Client"
return &Client{
Mutex: sync.Mutex{},
connections: make(map[*Conn]bool),
mutex: sync.Mutex{},
config: config,
tlsConfig: tlsConfig,
verifyOptions: x509.VerifyOptions{
......@@ -53,6 +55,7 @@ func NewClient(config *Config) *Client {
KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
MaxConstraintComparisions: 0,
},
handlers: make(map[string]MessageHandler),
}
}
......@@ -60,9 +63,9 @@ func NewClient(config *Config) *Client {
// with two-way handshakes to establish a secure channel.
// The server certificate is checked for a serial number, which is known from
// the configuration.
func (client *Client) Connect(address string, handlers map[string]MessageHandler) (*Conn, error) {
client.mutex.Lock()
defer client.mutex.Unlock()
func (client *Client) Connect(address string) (*Conn, error) {
client.Lock()
defer client.Unlock()
tlsConn, err := tls.Dial("tcp", address, client.tlsConfig)
if err != nil {
......@@ -77,7 +80,7 @@ func (client *Client) Connect(address string, handlers map[string]MessageHandler
}
// Register the connection
conn := newConn(ConnTypeServer, tlsConn, serialNumber, handlers)
conn := newConn(ConnTypeServer, tlsConn, serialNumber, client.handlers)
client.connections[conn] = true
go client.handleConnection(conn)
......@@ -109,8 +112,8 @@ func (client *Client) checkConnection(address string, conn *tls.Conn) (election.
// Close all connections. An error is returned if some connections failed.
func (client *Client) Close() error {
client.mutex.Lock()
defer client.mutex.Unlock()
client.Lock()
defer client.Unlock()
for conn, active := range client.connections {
if active {
......@@ -126,8 +129,8 @@ func (client *Client) Close() error {
// CloseConnection closes the given connection
func (client *Client) CloseConnection(conn *Conn) error {
client.mutex.Lock()
defer client.mutex.Unlock()
client.Lock()
defer client.Unlock()
return conn.Close()
}
......@@ -136,7 +139,29 @@ func (client *Client) handleConnection(conn *Conn) {
messageLoop(conn)
// Remove unregister connection because it closes when handleConnection returns
client.mutex.Lock()
defer client.mutex.Unlock()
client.Lock()
defer client.Unlock()
delete(client.connections, conn)
}
// Broadcast message to all current connections.
func (client *Client) Broadcast(message *Message) {
client.Lock()
defer client.Unlock()
log.Printf("%s: broadcasting message '%s'\n", logName, message.Command)
for conn := range client.connections {
err := conn.WriteMessage(message)
if err != nil {
log.Printf("%s: error: %s\n", logName, err)
}
}
}
// Registers a message handler for the given message.
func (client *Client) OnMessage(message string, handler MessageHandler) {
client.Lock()
defer client.Unlock()
client.handlers[message] = handler
}
......@@ -20,20 +20,22 @@ import (
// certificate must have extended key usage for client authorization.
type Server struct {
Client
sync.Mutex
/// Status of each connection.
// If true, then the connection is active na dif false,
// then the connection is closing.
connections map[*Conn]bool
// The connection listener
listener net.Listener
mutex sync.Mutex
// True if server.Close() has been called
closing bool
config *Config
tlsConfig *tls.Config
// Message handlers for incomming server and client connections.
serverHandlers map[string]MessageHandler
clientHandlers map[string]MessageHandler
}
// NewServer returns a new server using the provided message handlers
......@@ -42,8 +44,8 @@ func NewServer(config *Config) *Server {
logName = "Server"
return &Server{
Client: Client{
Mutex: sync.Mutex{},
connections: make(map[*Conn]bool),
mutex: sync.Mutex{},
config: config,
tlsConfig: tlsConfig,
verifyOptions: x509.VerifyOptions{
......@@ -54,19 +56,21 @@ func NewServer(config *Config) *Server {
MaxConstraintComparisions: 0,
},
},
connections: make(map[*Conn]bool),
mutex: sync.Mutex{},
closing: false,
config: config,
tlsConfig: tlsConfig,
Mutex: sync.Mutex{},
connections: make(map[*Conn]bool),
closing: false,
config: config,
tlsConfig: tlsConfig,
serverHandlers: make(map[string]MessageHandler),
clientHandlers: make(map[string]MessageHandler),
}
}
// Start the server. The server listens on the specified port.
// If the port is 0, then it is assigned a port by the OS.
func (server *Server) Start(port int, acceptCallback func(ConnType) map[string]MessageHandler) error {
server.mutex.Lock()
defer server.mutex.Unlock()
func (server *Server) Start(port int) error {
server.Lock()
defer server.Unlock()
listenAddr := ":"
if port != 0 {
......@@ -91,15 +95,15 @@ func (server *Server) Start(port int, acceptCallback func(ConnType) map[string]M
}
log.Printf("%s: listening on address: %s\n", logName, addr+":"+portStr)
go server.listen(listener, acceptCallback)
go server.listen(listener)
return nil
}
// Close the server listener and all connections. If an error occured
// some connections might not be closed.
func (server *Server) Close() error {
server.mutex.Lock()
defer server.mutex.Unlock()
server.Lock()
defer server.Unlock()
err := server.Client.Close()
if err != nil {
......@@ -123,7 +127,7 @@ func (server *Server) Close() error {
return nil
}
func (server *Server) listen(listener net.Listener, acceptCallback func(ConnType) map[string]MessageHandler) {
func (server *Server) listen(listener net.Listener) {
defer listener.Close()
for {
......@@ -149,14 +153,16 @@ func (server *Server) listen(listener net.Listener, acceptCallback func(ConnType
log.Printf("%s: closing connection: %s\n", logName, conn.RemoteAddr())
continue
}
var handler map[string]MessageHandler
if t == ConnTypeClient {
handler = server.clientHandlers
log.Printf("%s: accepted client connection: %s\n", logName, conn.RemoteAddr())
} else {
handler = server.serverHandlers
log.Printf("%s: accepted server connection: %s\n", logName, conn.RemoteAddr())
}
go server.handleConnection(newConn(t, tlsConn, sn, acceptCallback(t)))
go server.handleConnection(newConn(t, tlsConn, sn, handler))
}
}
......@@ -199,13 +205,27 @@ func (server *Server) checkConnection(conn *tls.Conn) (ConnType, election.Unique
func (server *Server) handleConnection(conn *Conn) {
defer conn.Close()
server.mutex.Lock()
server.Lock()
server.connections[conn] = true
server.mutex.Unlock()
server.Unlock()
messageLoop(conn)
server.mutex.Lock()
server.Lock()
delete(server.connections, conn)
server.mutex.Unlock()
server.Unlock()
}
// Register a message for handler for server connections
func (server *Server) OnServerMessage(message string, handler MessageHandler) {
server.Lock()
defer server.Unlock()
server.serverHandlers[message] = handler
}
// Register a message for client connections
func (server *Server) OnClientMessage(message string, handler MessageHandler) {
server.Lock()
defer server.Unlock()
server.clientHandlers[message] = handler
}
File mode changed from 100644 to 100755
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment