Commit 46f31fa3 authored by Anders Jensen Løvig's avatar Anders Jensen Løvig
Browse files

Merge branch 'agreement' into 'master'

Impl. agreement for good ballots

Closes #2 and #1

See merge request !1
parents 0227cb44 c566b5fe
Pipeline #21852 passed with stages
in 2 minutes and 24 seconds
......@@ -143,3 +143,15 @@ func (params *Params) MulProofs(proofs []Proof) Proof {
return result
}
func (proof Proof) Equals(other Proof) bool {
if len(other) != len(proof) {
return false
}
for i, e := range proof {
if other[i].Cmp(e) != 0 {
return false
}
}
return true
}
package election
import (
"errors"
"fmt"
"log"
"sync"
"github.com/gonum/stat/combin"
"github.com/google/uuid"
)
type BallotList map[string]uuid.UUID
type Agreement struct {
sync.Mutex
// The minimum servers required for being able to tally
required int
// Map of ballot lists receveid from servers
ourList BallotList
ballotLists map[UniqueID]BallotList
}
func NewAgreement(required int) *Agreement {
return &Agreement{
Mutex: sync.Mutex{},
required: required,
ballotLists: make(map[UniqueID]BallotList),
}
}
func (a *Agreement) AddList(serverID UniqueID, ballotList map[string]uuid.UUID) error {
a.Lock()
defer a.Unlock()
if _, ok := a.ballotLists[serverID]; ok {
return fmt.Errorf("server %s already sent ballot list", serverID)
}
a.ballotLists[serverID] = ballotList
return nil
}
func (a *Agreement) TallyList() ([]uuid.UUID, error) {
a.Lock()
defer a.Unlock()
// First create a list of received ballot lists
bls := make([]BallotList, 0, len(a.ballotLists))
for _, b := range a.ballotLists {
bls = append(bls, b)
}
// Find the maximum subset
hashes := MaxSubset(a.required, bls)
if len(hashes) == 0 {
return nil, errors.New("cannot agree on ballots")
}
// We know create a list of ballots we have in the subset.
list := make([]uuid.UUID, 0, len(a.ourList))
for h := range hashes {
if id, ok := a.ourList[h]; ok {
list = append(list, id)
delete(hashes, h) // Delete hashes that we have.
}
}
// If hashes is non-empty then we are not part of the max-subset.
if len(hashes) != 0 {
return nil, fmt.Errorf("missing %d ballots to tally", len(hashes))
}
return list, nil
}
func MaxSubset(t int, ballotLists []BallotList) map[string]bool {
n := len(ballotLists)
combis := combin.Combinations(n, t)
log.Printf("N: %d, combis: %d\n", n, len(combis))
subsets := make([]map[string]bool, len(combis), n)
for i, ls := range combis {
sets := make([]BallotList, 0, len(ls))
for _, i := range ls {
sets = append(sets, ballotLists[i])
}
subsets[i] = Intersection(sets)
}
max := 0
for i := 0; i < len(subsets); i++ {
if len(subsets[i]) > max {
max = i
}
}
return subsets[max]
}
func Intersection(superset []BallotList) map[string]bool {
section := make(map[string]int, len(superset[0]))
for _, set := range superset {
for hash := range set {
section[hash]++
}
}
set := make(map[string]bool)
for hash, num := range section {
if num >= len(superset) {
set[hash] = true
}
}
return set
}
package election
import (
"math/big"
"strconv"
"testing"
"github.com/google/uuid"
)
func TestIntersection(t *testing.T) {
N := 7
elem := make([]uuid.UUID, N)
for i := 0; i < N; i++ {
elem[i] = uuid.New()
}
superset := []BallotList{
{"0": elem[0], "1": elem[1], "2": elem[2], "3": elem[3]},
{"0": elem[0], "1": elem[1], "2": elem[2], "4": elem[4]},
{"0": elem[0], "1": elem[1], "2": elem[2], "5": elem[5]},
{"0": elem[0], "1": elem[1], "2": elem[2], "6": elem[6]},
}
subset := Intersection(superset)
if !subset["0"] {
t.Error()
}
if !subset["1"] {
t.Error()
}
if !subset["2"] {
t.Error()
}
if subset["3"] {
t.Error()
}
if subset["4"] {
t.Error()
}
if subset["5"] {
t.Error()
}
if subset["6"] {
t.Error()
}
}
func TestMaxSubset(t *testing.T) {
N := 7
elem := make([]uuid.UUID, N)
for i := 0; i < N; i++ {
elem[i] = uuid.New()
}
superset := []BallotList{
{"0": elem[0], "1": elem[1], "2": elem[2], "3": elem[3], "4": elem[4]},
{"0": elem[0], "1": elem[1], "2": elem[2], "4": elem[4]},
{"0": elem[0], "1": elem[1], "2": elem[2], "5": elem[5]},
{"0": elem[0], "1": elem[1], "2": elem[2], "4": elem[4], "6": elem[6]},
{"6": elem[6]},
}
subset := MaxSubset(4, superset)
if !subset["0"] {
t.Error()
}
if !subset["1"] {
t.Error()
}
if !subset["2"] {
t.Error()
}
if subset["3"] {
t.Error()
}
if subset["4"] {
t.Error()
}
if subset["5"] {
t.Error()
}
if subset["6"] {
t.Error()
}
}
func TestBallotIntersect(t *testing.T) {
lists := []BallotList{
map[string]uuid.UUID{},
map[string]uuid.UUID{},
map[string]uuid.UUID{},
}
xs := createXS(3)
ballots := createAllBallots([]*big.Int{
big.NewInt(0), big.NewInt(1), big.NewInt(0), big.NewInt(0),
big.NewInt(1), big.NewInt(0), big.NewInt(1), big.NewInt(0),
big.NewInt(0), big.NewInt(1),
}, 3, 2, xs)
for _, bmap := range ballots {
for id, b := range bmap {
i, _ := strconv.Atoi(string(id))
lists[i-1][b.Hash()] = b.ID
}
}
inter := Intersection(lists)
if len(inter) != 10 {
t.Error()
}
for i := 0; i < len(lists); i++ {
count := 0
for h := range lists[i] {
if _, ok := inter[h]; ok {
count++
}
}
if count != 10 {
t.Error(i)
}
}
}
func TestBallotMaxSubset(t *testing.T) {
lists := []BallotList{
map[string]uuid.UUID{},
map[string]uuid.UUID{},
map[string]uuid.UUID{},
}
xs := createXS(3)
ballots := createAllBallots([]*big.Int{
big.NewInt(0), big.NewInt(1), big.NewInt(0), big.NewInt(0),
big.NewInt(1), big.NewInt(0), big.NewInt(1), big.NewInt(0),
big.NewInt(0), big.NewInt(1),
}, 3, 2, xs)
for _, bmap := range ballots {
for id, b := range bmap {
i, _ := strconv.Atoi(string(id))
lists[i-1][b.Hash()] = b.ID
}
}
inter := MaxSubset(2, lists)
if len(inter) != 10 {
t.Error()
}
for i := 0; i < len(lists); i++ {
count := 0
for h := range lists[i] {
if _, ok := inter[h]; ok {
count++
}
}
if count != 10 {
t.Error(i)
}
}
}
......@@ -9,6 +9,7 @@ import (
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"log"
......@@ -53,6 +54,23 @@ func (ballot *Ballot) Verify(params *common.Params) bool {
return sigma.NewParams(params).Verify(ballot.Commits[0], ballot.Proofs)
}
func (ballot *Ballot) Hash() string {
// This information
toHash := struct {
ID uuid.UUID
Timestamp time.Time
Commits pedersen.Proof
}{
ballot.ID,
ballot.Timestamp,
ballot.Commits,
}
data, _ := json.Marshal(toHash)
h := sha256.Sum256(data)
return hex.EncodeToString(h[:])
}
// CreateBallots returns n ballots each containing a single share
func CreateBallots(t int, xs []*big.Int, vote *big.Int) map[UniqueID]*Ballot {
params := common.DefaultParams()
......@@ -61,11 +79,12 @@ func CreateBallots(t int, xs []*big.Int, vote *big.Int) map[UniqueID]*Ballot {
proof := sigma.NewParams(params).Prove(vote, binder, commits[0])
timestamp := time.Now()
id := uuid.New()
ballots := make(map[UniqueID]*Ballot)
for _, share := range shares {
ballots[UniqueID(share.X.String())] = &Ballot{
ID: uuid.New(),
ID: id,
Timestamp: timestamp,
Share: share,
Commits: commits,
......@@ -101,8 +120,35 @@ func (box *BallotBox) Put(ballot *Ballot) error {
return nil
}
// ballotList returns a mapping from ballot hash to their uuid.
func (box *BallotBox) ballotList() BallotList {
box.Lock()
defer box.Unlock()
list := make(map[string]uuid.UUID)
for id, ballot := range box.ballots {
list[ballot.Hash()] = id
}
return list
}
// Filter ballots by returning a list of valid ballots
func (box *BallotBox) Filter() []uuid.UUID {
good := []uuid.UUID{}
for id, ballot := range box.ballots {
if box.X.Cmp(ballot.Share.X) != 0 {
continue
}
if ballot.Verify(box.params) {
good = append(good, id)
}
}
return good
}
// Tally ballots
func (box *BallotBox) Tally() *Tally {
func (box *BallotBox) Tally(filter []uuid.UUID) *Tally {
box.Lock()
defer box.Unlock()
......@@ -112,7 +158,8 @@ func (box *BallotBox) Tally() *Tally {
commits := make([]pedersen.Proof, 0, size)
valids := 0
for _, ballot := range box.ballots {
for _, id := range filter {
ballot := box.ballots[id]
if box.X.Cmp(ballot.Share.X) == 0 {
if ballot.Verify(box.params) {
shares = append(shares, ballot.Share)
......@@ -120,8 +167,9 @@ func (box *BallotBox) Tally() *Tally {
valids++
}
} else {
log.Printf("Election: already contains ballot with x=%d\n", ballot.Share.X)
log.Printf("Election: ballot %s have x=%d\n", ballot.ID, ballot.Share.X)
}
}
log.Printf("Election: processed %d ballots with %d/%d valid", valids, valids, size)
......
......@@ -37,3 +37,22 @@ func TestBallotOne(t *testing.T) {
}
}
}
func TestBallotHash(t *testing.T) {
ballots := CreateBallots(5, []*big.Int{
big.NewInt(1),
big.NewInt(2),
big.NewInt(3),
big.NewInt(4),
big.NewInt(5),
}, big.NewInt(1))
h := ballots["1"].Hash()
t.Log("Hash", h)
for i, ballot := range ballots {
if ballot.Hash() != h {
t.Errorf("Ballot %s have hash %s", i, h)
}
}
}
......@@ -22,6 +22,8 @@ const (
ReasonCloseTimeout = "close timeout"
ReasonVotes = "all votes received"
ReasonAgreement = "majority"
ReasonTallyTimeout = "tally timeout"
ReasonBallotList = "all ballot lists received"
ReasonTally = "received enough tallies"
)
......@@ -31,9 +33,11 @@ type Phase int
// Election phases
const (
PhaseNotStarted Phase = iota
PhaseCollecting // When voters can send ballots
PhaseCollect // When voters can send ballots
PhaseCloseWait // When server waits after close
PhaseClosed // When waiting for other servers to close
PhasePreTally // When servers send ballot lists
PhaseTally // When servers sent tallies
PhaseTallying // When tallying
PhaseResult // When a result is available
)
......@@ -44,10 +48,12 @@ func (p Phase) String() string {
return "PhaseNotStarted"
case PhaseCloseWait:
return "PhaseCloseWait"
case PhaseCollecting:
case PhaseCollect:
return "PhaseCollecting"
case PhaseClosed:
return "PhaseClosed"
case PhasePreTally:
return "PhasePreTally"
case PhaseTallying:
return "PhaseTallying"
case PhaseResult:
......@@ -88,7 +94,7 @@ type Election struct {
Participants Participants
// The ballot box, which handles ballots received from voters
ballotBox *BallotBox
// The tally bix, which hanges tallies received from servers
// The tally box, which hanges tallies received from servers
tallyBox *TallyBox
// Result of the election
Result *Result
......@@ -97,6 +103,8 @@ type Election struct {
Deadline time.Time
// Struct containing server status about the election.
Status Status
// Agreement handles the agreement protocol. This decides which ballots to tally.
Agreement *Agreement
// Called when the final result is available
resultCallback func(*Result)
}
......@@ -151,14 +159,16 @@ func NewElection(config *Config) *Election {
},
resultCallback: config.ResultCallback,
Mutex: sync.Mutex{},
Agreement: NewAgreement(config.RequiredServers),
}
config.Server.OnClientMessage(MsgBallot, e.OnBallot)
config.Server.OnServerMessage(MsgClosing, e.OnClose)
config.Server.OnServerMessage(MsgTally, e.OnTally)
config.Server.OnServerMessage(MsgBallotList, e.OnBallotList)
return e
}
// Start the election, by opening for incomming ballots.
// Start the election, by opening for incoming ballots.
// Returns error if the deadline is not a future time.
func (election *Election) Start() error {
election.Lock()
......@@ -181,7 +191,7 @@ func (election *Election) Start() error {
defer election.Unlock()
// Deadline reached: Close election
if election.Status.Phase == PhaseCollecting {
if election.Status.Phase == PhaseCollect {
election.nextPhase(ReasonDeadline)
}
}(election)
......@@ -194,8 +204,8 @@ func (election *Election) nextPhase(reason Reason) {
switch election.Status.Phase {
case PhaseNotStarted:
log.Println("==================== Collecting ====================")
election.Status.Phase = PhaseCollecting
case PhaseCollecting:
election.Status.Phase = PhaseCollect
case PhaseCollect:
log.Println("====================== Closed ======================")
// The closed phase is twofold. If the reason for closing is another
// than all votes received (ReasonVotes), then we must wait for late
......@@ -212,24 +222,7 @@ func (election *Election) nextPhase(reason Reason) {
if reason != ReasonVotes && election.Config.CloseSleep != 0 {
log.Printf("Election: waiting %s for late ballots\n", election.Config.CloseSleep)
election.Status.Phase = PhaseCloseWait
go func(e *Election) {
var reason Reason
// Wait for timeout or all votes received.
select {
case <-time.After(time.Until(now.Add(e.Config.CloseSleep))):
reason = ReasonCloseTimeout
case <-e.Status.PhaseChannel:
reason = ReasonVotes
}
// Then next phase
e.Lock()
defer e.Unlock()
// Actually close the election
e.Status.Phase = PhaseClosed
e.nextPhase(reason)
}(election)
go delay(election, now, ReasonCloseTimeout, ReasonVotes, PhaseClosed)
break
} else {
// If we received all votes go directly to next phase
......@@ -245,29 +238,51 @@ func (election *Election) nextPhase(reason Reason) {
break
}
// Else enough servers to start tallying
// If the ballot box is empty we can go directly to result
log.Println("===================== Tallying =====================")
// Tallying is two phased. First we wait for servers to sent ballot lists.
// Then we find the largest intersection of ballot lists and tally this intersection.
ballotList := election.ballotBox.ballotList()
election.Agreement.ourList = ballotList
_ = election.Agreement.AddList(election.ServerID, ballotList)
election.server.Broadcast(network.NewMessage(MsgBallotList, ballotList))
if len(election.Agreement.ballotLists) < election.Participants.Servers && election.Config.CloseSleep != 0 {
election.Status.Phase = PhasePreTally
log.Printf("Election: waiting %s for ballots lists\n", election.Config.CloseSleep)
go delay(election, time.Now(), ReasonTallyTimeout, ReasonBallotList, PhaseTally)
break
}
fallthrough
case PhaseTally:
election.Status.Phase = PhaseTallying
// Only tally if we have ballots
if election.ballotBox.Size() != 0 {
log.Println("===================== Tallying =====================")
election.Status.Phase = PhaseTallying
tally := election.ballotBox.Tally()
election.Status.TalliedServers[election.ServerID] = true
err := election.tallyBox.Put(tally)
if err != nil {
panic(err) // This should never happen
list, err := election.Agreement.TallyList()
if err == nil {
tally := election.ballotBox.Tally(list)
election.Status.TalliedServers[election.ServerID] = true
err = election.tallyBox.Put(tally)
if err != nil {
panic(err) // Should not happen
}
election.server.Broadcast(network.NewMessage(MsgTally, tally))
break
}
election.server.Broadcast(network.NewMessage(MsgTally, tally))
break
log.Println("Should really wait for tallies from other servers !!! :D")
}
fallthrough
case PhaseTallying:
log.Println("====================== Result ======================")
election.Status.Phase = PhaseResult
result := election.createResults()
election.resultCallback(result)
default:
return // No new phase
}
log.Println("Election: new phase:", election.Status.Phase)
}
func (election *Election) createResults() *Result {
......@@ -280,3 +295,21 @@ func (election *Election) createResults() *Result {
Tally: int(election.tallyBox.Combine().Int64()),
}
}
func delay(e *Election, now time.Time, r1, r2 Reason, p Phase) {
var reason Reason
// Wait for timeout or all votes received.
select {
case <-time.After(time.Until(now.Add(e.Config.CloseSleep))):
reason = r1
case <-e.Status.PhaseChannel:
reason = r2
}
// Then next phase
e.Lock()
defer e.Unlock()