Commit c5f97d4d authored by Anders Jensen Løvig's avatar Anders Jensen Løvig
Browse files

Move election message handling

parent 293d84d7
......@@ -66,30 +66,6 @@ func init() {
server.Config.CertPath = fmt.Sprintf("config/cert/server/s%d.crt", serialNumber)
server.Config.KeyPath = fmt.Sprintf("config/cert/server/s%d.key", serialNumber)
server.Server = network.NewServer(server.Config)
server.Server.OnClientMessage(election.MsgBallot, func(conn *network.Conn, data network.MessageData) *network.Message {
var ballot election.Ballot
err := data.Unmarshal(&ballot)
if err != nil {
log.Println("Main: error:", err)
} else {
go server.Election.HandleBallot(conn.SerialNumber(), &ballot)
}
return network.NewMessage(election.MsgBallotAck, nil)
})
server.Server.OnServerMessage(election.MsgClosing, func(conn *network.Conn, data network.MessageData) *network.Message {
server.Election.HandleClosing(conn.SerialNumber())
return nil
})
server.Server.OnServerMessage(election.MsgTally, func(conn *network.Conn, data network.MessageData) *network.Message {
var tally election.Tally
err := data.Unmarshal(&tally)
if err != nil {
log.Println("Main: error:", err)
} else {
go server.Election.HandleTally(conn.SerialNumber(), &tally)
}
return nil
})
// Initialize election
servers := len(server.Config.Servers)
......@@ -105,6 +81,10 @@ func init() {
ResultCallback: resultCallback,
}
server.Election = election.NewElection(config)
server.Server.OnClientMessage(election.MsgBallot, server.Election.OnBallot)
server.Server.OnServerMessage(election.MsgClosing, server.Election.OnClose)
server.Server.OnServerMessage(election.MsgTally, server.Election.OnTally)
}
func startServer() {
......@@ -140,7 +120,7 @@ func startElection() {
func main() {
startServer()
time.Sleep(2 * time.Second) // Give servers time to start before connecting
time.Sleep(5 * time.Second) // Give servers time to start before connecting
conenctServers()
time.Sleep(1 * time.Second) // To let servers connect before starting election
startElection()
......
......@@ -54,7 +54,7 @@ func init() {
// Init client
voter.RecievedCount = 0
voter.Client = network.NewClient(config)
voter.Client.OnMessage(election.MsgBallotAck, func(conn *network.Conn, data network.MessageData) *network.Message {
voter.Client.OnMessage(election.MsgBallotAck, func(conn *network.Conn, data network.MessageData) {
voter.Lock()
defer voter.Unlock()
......@@ -65,7 +65,6 @@ func init() {
if voter.RecievedCount == voter.ConnectionCount {
voter.Done <- true
}
return nil
})
voter.Mutex = sync.Mutex{}
......@@ -99,7 +98,7 @@ func main() {
sendBallots(ballots)
// Wait for acknowledgements or timeout
timeout := network.TimeoutAfter(10 * time.Second)
timeout := time.After(10 * time.Second)
select {
case <-voter.Done:
log.Println("Main: all servers acknowledged")
......@@ -149,7 +148,7 @@ func sendBallots(ballots map[election.UniqueID]*election.Ballot) {
log.Println("Main: error:", err)
continue
}
ballot := ballots[conn.SerialNumber()]
ballot := ballots[election.UniqueID(conn.SerialNumber())]
message := network.NewMessage(election.MsgBallot, ballot)
err = conn.WriteMessage(message)
......
......@@ -13,14 +13,6 @@ import (
// UniqueID represents a unique id assigned to each server and voter.
type UniqueID string
// Election message types
const (
MsgBallot = "ballot"
MsgBallotAck = "ballot_ack"
MsgClosing = "election_closing"
MsgTally = "election_tally"
)
type Reason string
const (
......@@ -231,84 +223,6 @@ func (election *Election) nextPhase(reason Reason) {
}
}
// HandleBallot puts ballot in the ballot box if ballot is valid.
func (election *Election) HandleBallot(voterID UniqueID, ballot *Ballot) {
election.Lock()
defer election.Unlock()
if election.Status.Phase == PhaseCollecting {
if _, ok := election.Status.Turnout[voterID]; ok {
log.Printf("Election: error: voter %s tried to vote again\n", voterID)
return
}
err := election.ballotBox.Put(ballot)
if err != nil {
log.Println("Election: error:", err)
return
}
election.Status.Turnout[voterID] = true
if election.ballotBox.Size() >= election.Participants.Voters {
election.nextPhase(ReasonVotes)
}
} else {
log.Printf("Election: error: recevied ballot %s outside of collecting phase\n", ballot.ID)
}
}
// HandleClosing handles that a server with id closed.
func (election *Election) HandleClosing(id UniqueID) {
election.Lock()
defer election.Unlock()
if election.Status.Phase == PhaseNotStarted {
log.Printf("Election: error: server %s closed before election started\n", id)
return
}
if _, ok := election.Status.ClosedServers[id]; ok {
log.Printf("Election: error: server %s already closed\n", id)
return
}
election.Status.ClosedServers[id] = true
if len(election.Status.ClosedServers) >= election.Participants.RequiredServers &&
election.Status.Phase < PhaseTallying {
election.nextPhase(ReasonAgreement)
}
}
// HandleTally handles the tallies received from servers
func (election *Election) HandleTally(serverID UniqueID, tally *Tally) {
election.Lock()
defer election.Unlock()
if election.Status.Phase == PhaseNotStarted {
log.Printf("Election: error: server %s tallied before election started\n", serverID)
return
}
if _, ok := election.Status.ClosedServers[serverID]; !ok {
log.Printf("Election: error: received tally from server %s, who have not closed\n", serverID)
return
}
if _, ok := election.Status.TalliedServers[serverID]; ok {
log.Printf("Election: error: server %s have already sent a tally\n", serverID)
return
}
err := election.tallyBox.Put(tally)
if err != nil {
log.Println("Election: error:", err)
return
}
election.Status.TalliedServers[serverID] = true
if len(election.Status.TalliedServers) >= election.Participants.RequiredServers &&
election.Status.Phase < PhaseResult {
election.nextPhase(ReasonTally)
}
}
func (election *Election) createResults() *Result {
if election.tallyBox.Size() == 0 {
return nil
......
......@@ -252,10 +252,10 @@ func TestCloseDeadline(t *testing.T) {
func TestCloseAgreement(t *testing.T) {
_ = testCloseCondition(t, ReasonAgreement, PhaseTallying, func(e *Election) {
ballots := CreateBallots(3, createXS(1), big.NewInt(1))
e.HandleBallot(UniqueID(strconv.Itoa(1)), ballots["1"])
e.handleBallot(UniqueID(strconv.Itoa(1)), ballots["1"])
e.HandleClosing("2")
e.HandleClosing("3")
e.handleClosing("2")
e.handleClosing("3")
})
}
......@@ -263,7 +263,7 @@ func TestCloseAllVotes(t *testing.T) {
_ = testCloseCondition(t, ReasonVotes, PhaseClosed, func(e *Election) {
for i := 1; i <= e.Participants.Voters; i++ {
ballots := CreateBallots(3, createXS(1), big.NewInt(1))
e.HandleBallot(UniqueID(strconv.Itoa(i)), ballots["1"])
e.handleBallot(UniqueID(strconv.Itoa(i)), ballots["1"])
}
})
}
......@@ -271,11 +271,11 @@ func TestCloseAllVotes(t *testing.T) {
func TestTallyDeadline(t *testing.T) {
_ = testCloseCondition(t, ReasonDeadline, PhaseTallying, func(e *Election) {
ballots := CreateBallots(3, createXS(1), big.NewInt(1))
e.HandleBallot(UniqueID(strconv.Itoa(1)), ballots["1"])
e.handleBallot(UniqueID(strconv.Itoa(1)), ballots["1"])
// By closing another server, we should go directly to tallying when
// deadline is reached
e.HandleClosing("2")
e.handleClosing("2")
})
}
......@@ -294,7 +294,7 @@ func TestNoVotesResult(t *testing.T) {
timeout <- true
}()
e.HandleClosing("2")
e.handleClosing("2")
select {
case <-done:
......@@ -309,9 +309,9 @@ func TestNoVotesResult(t *testing.T) {
func TestTallyAfterClose(t *testing.T) {
_ = testCloseCondition(t, ReasonDeadline, PhaseTallying, func(e *Election) {
ballots := CreateBallots(3, createXS(1), big.NewInt(1))
e.HandleBallot(UniqueID(strconv.Itoa(1)), ballots["1"])
e.handleBallot(UniqueID(strconv.Itoa(1)), ballots["1"])
e.HandleClosing("2")
e.handleClosing("2")
})
}
......@@ -321,8 +321,8 @@ func TestOnlyVoteOnce(t *testing.T) {
ballots := CreateBallots(3, createXS(1), big.NewInt(1))
e.HandleBallot("4", ballots["1"])
e.HandleBallot("4", ballots["1"])
e.handleBallot("4", ballots["1"])
e.handleBallot("4", ballots["1"])
if e.ballotBox.Size() != 1 {
t.Errorf("Expected ballot box size 1, got %d", e.ballotBox.Size())
......@@ -343,23 +343,23 @@ func TestHandleTally(t *testing.T) {
}
_ = testCloseCondition(t, ReasonAgreement, PhaseResult, func(e *Election) {
e.HandleBallot("4", ballots["1"])
e.HandleTally("2", tally1)
e.handleBallot("4", ballots["1"])
e.handleTally("2", tally1)
if e.tallyBox.Size() != 0 {
t.Errorf("Expected 0 tally, got %d", e.tallyBox.Size())
}
e.HandleClosing("2")
e.handleClosing("2")
e.HandleTally("2", tally1)
e.HandleTally("2", tally1)
e.handleTally("2", tally1)
e.handleTally("2", tally1)
if e.tallyBox.Size() != 1 {
t.Errorf("Expected 1 tally, got %d", e.tallyBox.Size())
}
fmt.Println("Phase:", e.Status.Phase)
e.HandleClosing("3")
e.handleClosing("3")
fmt.Println("Phase:", e.Status.Phase)
e.HandleTally("3", tally2)
e.handleTally("3", tally2)
fmt.Println("Phase:", e.Status.Phase)
})
}
package election
import (
"bsc-shamir/network"
"log"
)
// This file implements network message handlers used by the election
const (
// Voter -> Server: contains a ballot
MsgBallot = "ballot"
// Server -> Voter: Acknowledge ballot
MsgBallotAck = "ballot_ack"
// Server -> Server: Sent when server is closing
MsgClosing = "election_closing"
// Server -> Server: Sent when server have tallied
MsgTally = "election_tally"
)
func (e *Election) OnBallot(conn *network.Conn, data network.MessageData) {
var ballot Ballot
err := data.Unmarshal(&ballot)
if err != nil {
log.Println("Election: error:", err)
return
}
// Use the connections certificate serial number as the voter id.
voterID := UniqueID(conn.SerialNumber())
go e.handleBallot(voterID, &ballot)
err = conn.WriteMessage(network.NewMessage(MsgBallotAck, nil))
if err != nil {
log.Println("Election: error:", err)
}
}
func (e *Election) handleBallot(voterID UniqueID, ballot *Ballot) {
e.Lock()
defer e.Unlock()
if e.Status.Phase == PhaseCollecting {
if _, ok := e.Status.Turnout[voterID]; ok {
log.Printf("Election: error: voter %s tried to vote again\n", voterID)
return
}
err := e.ballotBox.Put(ballot)
if err != nil {
log.Println("Election: error:", err)
return
}
e.Status.Turnout[voterID] = true
if e.ballotBox.Size() >= e.Participants.Voters {
e.nextPhase(ReasonVotes)
}
} else {
log.Printf("Election: error: recevied ballot %s outside of collecting phase\n", ballot.ID)
}
}
func (e *Election) OnClose(conn *network.Conn, data network.MessageData) {
e.handleClosing(UniqueID(conn.SerialNumber()))
}
// handleClosing handles that a server with id closed.
func (election *Election) handleClosing(id UniqueID) {
election.Lock()
defer election.Unlock()
if election.Status.Phase == PhaseNotStarted {
log.Printf("Election: error: server %s closed before election started\n", id)
return
}
if _, ok := election.Status.ClosedServers[id]; ok {
log.Printf("Election: error: server %s already closed\n", id)
return
}
election.Status.ClosedServers[id] = true
if len(election.Status.ClosedServers) >= election.Participants.RequiredServers &&
election.Status.Phase < PhaseTallying {
election.nextPhase(ReasonAgreement)
}
}
func (e *Election) OnTally(conn *network.Conn, data network.MessageData) {
var tally Tally
err := data.Unmarshal(&tally)
if err != nil {
log.Println("Election: error:", err)
}
serverID := UniqueID(conn.SerialNumber())
go e.handleTally(serverID, &tally)
}
// handleTally handles the tallies received from servers
func (election *Election) handleTally(serverID UniqueID, tally *Tally) {
election.Lock()
defer election.Unlock()
if election.Status.Phase == PhaseNotStarted {
log.Printf("Election: error: server %s tallied before election started\n", serverID)
return
}
if _, ok := election.Status.ClosedServers[serverID]; !ok {
log.Printf("Election: error: received tally from server %s, who have not closed\n", serverID)
return
}
if _, ok := election.Status.TalliedServers[serverID]; ok {
log.Printf("Election: error: server %s have already sent a tally\n", serverID)
return
}
err := election.tallyBox.Put(tally)
if err != nil {
log.Println("Election: error:", err)
return
}
election.Status.TalliedServers[serverID] = true
if len(election.Status.TalliedServers) >= election.Participants.RequiredServers &&
election.Status.Phase < PhaseResult {
election.nextPhase(ReasonTally)
}
}
package network
import (
"bsc-shamir/election"
"crypto/tls"
"crypto/x509"
"errors"
......@@ -89,7 +88,7 @@ func (client *Client) Connect(address string) (*Conn, error) {
// Checks that the connection is to a server and returns the serial number.
// If the connection is not to a server, then an error is returned.
func (client *Client) checkConnection(address string, conn *tls.Conn) (election.UniqueID, error) {
func (client *Client) checkConnection(address string, conn *tls.Conn) (string, error) {
verifiedChains := conn.ConnectionState().VerifiedChains
if len(verifiedChains) != 1 {
return "", errors.New("multiple verified chains")
......@@ -107,7 +106,7 @@ func (client *Client) checkConnection(address string, conn *tls.Conn) (election.
return "", fmt.Errorf("unexpected serial number %s", cert.SerialNumber)
}
return election.UniqueID(cert.SerialNumber.String()), nil
return cert.SerialNumber.String(), nil
}
// Close all connections. An error is returned if some connections failed.
......
package network
import (
"bsc-shamir/election"
"bufio"
"encoding/binary"
"io"
......@@ -24,7 +23,7 @@ type Conn struct {
// serialNumber is a uniqely indentifier for remote party. No two connections
// will have the same serialNumber, unless they are connected to the same
// participant.
serialNumber election.UniqueID
serialNumber string
// Connection stuff
reader *bufio.Reader
......@@ -42,7 +41,7 @@ type Conn struct {
messageHandlers map[string]MessageHandler
}
func newConn(connType ConnType, conn net.Conn, serialNumber election.UniqueID, handlers map[string]MessageHandler) *Conn {
func newConn(connType ConnType, conn net.Conn, serialNumber string, handlers map[string]MessageHandler) *Conn {
return &Conn{
connType: connType,
conn: conn,
......@@ -57,7 +56,7 @@ func newConn(connType ConnType, conn net.Conn, serialNumber election.UniqueID, h
}
}
func (c *Conn) SerialNumber() election.UniqueID {
func (c *Conn) SerialNumber() string {
return c.serialNumber
}
......
......@@ -6,12 +6,10 @@ import (
"encoding/json"
"io"
"log"
"net"
"sync"
)
// MessageHandler maps commands to a message handler callback
type MessageHandler func(*Conn, MessageData) *Message
type MessageHandler func(*Conn, MessageData)
// MessageData represents the data carried by a message.
type MessageData json.RawMessage
......@@ -60,26 +58,6 @@ func readMessage(reader *bufio.Reader) (*Message, error) {
return message, err
}
func writeMessage(conn net.Conn, message *Message) error {
log.Printf("%s: sending '%s' to %s", logName, message.Command, conn.RemoteAddr())
data := marshal(message)
sizeBuffer := make([]byte, 4)
binary.BigEndian.PutUint32(sizeBuffer, uint32(len(data)))
_, err := conn.Write(sizeBuffer)
if err != nil {
return err
}
_, err = conn.Write(data)
if err != nil {
return err
}
return nil
}
func marshal(v interface{}) []byte {
data, err := json.Marshal(v)
if err != nil {
......@@ -97,8 +75,6 @@ func unmarshal(data []byte, v interface{}) error {
}
func messageLoop(conn *Conn) {
mutex := sync.Mutex{}
reader := bufio.NewReader(conn)
for {
// First wait for a message and check the returned error.
......@@ -118,17 +94,7 @@ func messageLoop(conn *Conn) {
if handler, ok := conn.messageHandlers[message.Command]; ok {
// Spawn go-routine so the message loop can listen for the next message
log.Printf("%s: received message '%s' from %s\n", logName, message.Command, conn.RemoteAddr())
go func(handler MessageHandler) {
response := handler(conn, message.Data)
if response != nil {
mutex.Lock()
err := writeMessage(conn, response)
if err != nil {
log.Printf("%s: error: %s\n", logName, err)
}
mutex.Unlock()
}
}(handler)
go handler(conn, message.Data)
} else {
// Zero tolerance for faulty servers. Just close the connection.
log.Printf("%s: error: %s sent unknown message: %s\n", logName, conn.RemoteAddr(), message.Command)
......
package network
import (
"bsc-shamir/election"
"crypto/tls"
"crypto/x509"
"errors"
......@@ -168,7 +167,7 @@ func (server *Server) listen(listener net.Listener) {
// Checks that the connection is to a server and returns the serial number.
// If the connection is not to a server, then an error is returned.
func (server *Server) checkConnection(conn *tls.Conn) (ConnType, election.UniqueID, error) {
func (server *Server) checkConnection(conn *tls.Conn) (ConnType, string, error) {
verifiedChains := conn.ConnectionState().VerifiedChains
if len(verifiedChains) == 0 {
return 0, "", errors.New("no verified chains")
......@@ -196,10 +195,10 @@ func (server *Server) checkConnection(conn *tls.Conn) (ConnType, election.Unique
return 0, "", err
}
return ConnTypeClient, election.UniqueID(cert.SerialNumber.String()), nil
return ConnTypeClient, cert.SerialNumber.String(), nil
}
return ConnTypeServer, election.UniqueID(cert.SerialNumber.String()), nil
return ConnTypeServer, cert.SerialNumber.String(), nil
}
func (server *Server) handleConnection(conn *Conn) {
......
......@@ -2,7 +2,6 @@ package network
import (
"net"
"time"
)
var logName string
......@@ -17,13 +16,3 @@ func findLocalAddress() (string, error) {
localAddr := conn.LocalAddr().(*net.UDPAddr)
return localAddr.IP.String(), nil
}
// TimeoutAfter returns a channel, which timeout after specified duration
func TimeoutAfter(duration time.Duration) chan bool {
channel := make(chan bool, 1)
go func() {
time.Sleep(duration)
channel <- true
}()
return channel
}
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment