Commit 91d638fc authored by Anders Jensen Løvig's avatar Anders Jensen Løvig
Browse files

It just works

parent 40c9c061
...@@ -5,7 +5,7 @@ import ( ...@@ -5,7 +5,7 @@ import (
"crycomp/internal/blood" "crycomp/internal/blood"
"crycomp/internal/crypto/garbled" "crycomp/internal/crypto/garbled"
"crycomp/internal/crypto/oblivious" "crycomp/internal/crypto/oblivious"
"crycomp/internal/util" "crycomp/internal/crypto/util"
"fmt" "fmt"
) )
...@@ -69,20 +69,22 @@ func NewParty(in chan []byte, out chan []byte, obliv *oblivious.Party) (p *Party ...@@ -69,20 +69,22 @@ func NewParty(in chan []byte, out chan []byte, obliv *oblivious.Party) (p *Party
func (alice *Party) RunAlice(x int) (result bool, err error) { func (alice *Party) RunAlice(x int) (result bool, err error) {
// 1. Garble // 1. Garble
// F, _ := garbled.NewTable(NumWires-NumInputs*2, 4, K*2)
alice.circuit.F.SetData(<-alice.receive) alice.circuit.F.SetData(<-alice.receive)
fmt.Println("Alice: Received F from Bob")
// 2. Encode Bob's input // 2. Encode Bob's input
Y := <-alice.receive Y := <-alice.receive
fmt.Println("Alice: Received Y from Bob")
// 3. Encode Alice's input // 3. Encode Alice's input
X, err := alice.obliv.Receive(x) X, err := alice.obliv.Receive(x)
fmt.Println("Alice: Oblivious transfer encoding from Bob")
if err != nil { if err != nil {
return return
} }
// Since Elgamal encryption might have removed leading zero bytes,
// we need to add them again
vLen := K / 8
if m := len(X) % vLen; m != 0 {
X = append(make([]byte, vLen-m), X...)
}
// 4. Evaluation // 4. Evaluation
evalInput := append(X, Y...) evalInput := append(X, Y...)
...@@ -90,15 +92,14 @@ func (alice *Party) RunAlice(x int) (result bool, err error) { ...@@ -90,15 +92,14 @@ func (alice *Party) RunAlice(x int) (result bool, err error) {
if err != nil { if err != nil {
return return
} }
fmt.Println("Alice: Circuit evaluated")
// 5b. Alice Output // 5b. Alice Output
alice.circuit.D, _ = garbled.NewTable(1, 2, K) alice.circuit.D, err = garbled.NewTable(1, 2, K)
if err != nil {
return
}
alice.circuit.D.SetData(<-alice.receive) alice.circuit.D.SetData(<-alice.receive)
result, err = garbled.Decode(alice.circuit.D, Z) return garbled.Decode(alice.circuit.D, Z)
fmt.Println("Alice: Decoded result")
return
} }
func (bob *Party) RunBob(y int) (err error) { func (bob *Party) RunBob(y int) (err error) {
...@@ -107,12 +108,10 @@ func (bob *Party) RunBob(y int) (err error) { ...@@ -107,12 +108,10 @@ func (bob *Party) RunBob(y int) (err error) {
if err != nil { if err != nil {
return return
} }
fmt.Println("Bob: Sending F to Alice")
bob.send <- bob.circuit.F.GetData() bob.send <- bob.circuit.F.GetData()
// 2. Encode Bob's input // 2. Encode Bob's input
Y := garbled.Encode(bob.circuit.E[1], util.Int2Bools(y, 3)) Y := garbled.Encode(bob.circuit.E[1], util.Int2Bools(y, 3))
fmt.Println("Bob: Sending Y to Alice")
bob.send <- Y bob.send <- Y
// 3. Encode Alice's input // 3. Encode Alice's input
...@@ -121,7 +120,6 @@ func (bob *Party) RunBob(y int) (err error) { ...@@ -121,7 +120,6 @@ func (bob *Party) RunBob(y int) (err error) {
for i := 0; i < len(data); i++ { for i := 0; i < len(data); i++ {
data[i] = garbled.Encode(bob.circuit.E[0], util.Int2Bools(i, 3)) data[i] = garbled.Encode(bob.circuit.E[0], util.Int2Bools(i, 3))
} }
fmt.Println("Bob: Oblivious transfer encoding to Alice")
err = bob.obliv.Send(data) err = bob.obliv.Send(data)
if err != nil { if err != nil {
return return
...@@ -131,42 +129,11 @@ func (bob *Party) RunBob(y int) (err error) { ...@@ -131,42 +129,11 @@ func (bob *Party) RunBob(y int) (err error) {
// do nothing // do nothing
// 5b. Alice output // 5b. Alice output
fmt.Println("Bob: Sending d to Alice")
bob.send <- bob.circuit.D.GetData() bob.send <- bob.circuit.D.GetData()
return return
} }
// func (p *Party) OTSend() {
// for i := 0; i < 3; i++ {
// // p.Listen(p.Circuit.E[i+3])
// p.Listen(nil)
// }
// }
// func (p *Party) OTReceive() (res [][]byte, err error) {
// res = make([][]byte, 3)
// var val int
// for i := range p.x {
// if p.x[i] {
// val = 1
// } else {
// val = 0
// }
// var result, err = p.ObliviousTransfer(2, val)
// if err != nil {
// return nil, err
// }
// res[i] = result
// }
// return res, nil
// }
// func OTransfer(sender, receiver *Party) {
// sender.Party, receiver.Party = oblivious.NewPartyPair()
// //sender.Listen(sender.Circuit.e)
// }
// RunProtocol runs the protocol between receiving blood type x and donor blood // RunProtocol runs the protocol between receiving blood type x and donor blood
// type y. // type y.
// outputMode: 0: Alice learns, 1: Bob learns else both learns // outputMode: 0: Alice learns, 1: Bob learns else both learns
...@@ -186,72 +153,10 @@ func RunProtocol(x, y int, outputMode int) (z bool, err error) { ...@@ -186,72 +153,10 @@ func RunProtocol(x, y int, outputMode int) (z bool, err error) {
z, err = p.A.RunAlice(x) z, err = p.A.RunAlice(x)
return return
// go func(bob *Party) {
// 3. Encode Alice's input
// Bob.InChan <- Bob.EncInput
// fmt.Printf("Bob has garbled his input and sent it to alice\n")
// // 3. Encode Alices input
// Bob.OTSend()
// fmt.Printf("Bob has finished garbling alices input\n")
// // 5. Output
// if outputMode != 1 { // Alice learns
// // Bob sends d
// Bob.InChan <- Bob.Circuit.D
// fmt.Printf("Bob sends D so Alice can learn\n")
// }
// if outputMode != 0 { // Bob learns
// // Alice sends Z
// Bob.result = garbled.Decode(Bob.Circuit.D, <-Bob.outChan)
// fmt.Printf("Bob has received Z so he can learn\n")
// }
// }(bob)
//go func (Alice *Party) {
// 1. Garble
// F := <-alice.TableChan
// alice.Circuit = &garbled.Circuit{F: &F}
// fmt.Printf("Alice has received F\n")
// // 2. Encode Bobs input
// alice.EncInput = <-alice.InChan
// // 3. Encode Alices input
// X, _ := alice.OTReceive()
// alice.EncInput = append(alice.EncInput, X...)
// fmt.Printf("Alice is finished garbling her input\n")
// // 4. Evaluation
// Z, eval_err := garbled.Eval(alice.Circuit.F, alice.EncInput)
// if eval_err != nil {
// return false, eval_err
// }
// fmt.Printf("Alice has evaluated the circuit\n")
// // 5. Output
// if outputMode != 1 { // Alice learns
// // Bob sends d
// alice.result = garbled.Decode(<-alice.InChan, Z)
// fmt.Printf("Alice has learned the output\n")
// }
// if outputMode != 0 { // Bob learns
// // Alice sends Z
// alice.outChan <- Z
// fmt.Printf("Alice has sent Z to bob\n")
// }
// //}(Alice)
// if outputMode != 1 {
// return alice.result, nil
// } else {
// return Bob.result, nil
// }
// return false, nil
} }
func main() { func main() {
bloodA, bloodB := blood.Type_ABn, blood.Type_ABp bloodA, bloodB := blood.Type_ABn, blood.Type_ABn
z, err := RunProtocol(bloodA, bloodB, 0) z, err := RunProtocol(bloodA, bloodB, 0)
if err != nil { if err != nil {
fmt.Println("Protocol failed with error:", err) fmt.Println("Protocol failed with error:", err)
......
...@@ -12,6 +12,7 @@ func TestProtocol(t *testing.T) { ...@@ -12,6 +12,7 @@ func TestProtocol(t *testing.T) {
for x := 0; x < n; x++ { for x := 0; x < n; x++ {
for y := 0; y < n; y++ { for y := 0; y < n; y++ {
// util.DebugWriter.Reset()
testName := fmt.Sprintf("(x=%s,y=%s)", blood.Names[x], blood.Names[y]) testName := fmt.Sprintf("(x=%s,y=%s)", blood.Names[x], blood.Names[y])
t.Run(testName, func(t *testing.T) { t.Run(testName, func(t *testing.T) {
z, err := RunProtocol(x, y, 0) z, err := RunProtocol(x, y, 0)
...@@ -24,28 +25,3 @@ func TestProtocol(t *testing.T) { ...@@ -24,28 +25,3 @@ func TestProtocol(t *testing.T) {
} }
} }
} }
func TestProtocolInSeries(t *testing.T) {
// Runs the protocol for all combinations of recipient and donor blood types.
n := len(blood.Table)
for x := 0; x < n; x++ {
for y := 0; y < n; y++ {
z, err := RunProtocol(x, y, 0)
if err != nil {
t.Errorf("Protocol error: %s", err)
} else if z != blood.Table[x][y] {
t.Fatalf("Failed blood compatibility test for index [%d,%d]. Expected %t, got %t", x, y, !z, z)
}
}
}
}
func TestProtocolFirst(t *testing.T) {
z, err := RunProtocol(0, 0, 0)
if err != nil {
t.Errorf("Protocol error: %s", err)
} else if z != blood.Table[0][0] {
t.Fatalf("Failed blood compatibility test for index [0,0]. Expected %t, got %t", !z, z)
}
}
...@@ -2,8 +2,7 @@ package garbled ...@@ -2,8 +2,7 @@ package garbled
import ( import (
"bytes" "bytes"
cryptoUtil "crycomp/internal/crypto/util" "crycomp/internal/crypto/util"
"crycomp/internal/util"
"crypto/sha256" "crypto/sha256"
"errors" "errors"
"fmt" "fmt"
...@@ -56,14 +55,23 @@ func (c *Circuit) GarbleBloodCircuit() (err error) { ...@@ -56,14 +55,23 @@ func (c *Circuit) GarbleBloodCircuit() (err error) {
// Create e // Create e
c.E = make([]*Table, 2) c.E = make([]*Table, 2)
c.E[0], _ = NewTable(c.numInputs/2, 2, c.k) c.E[0], err = NewTable(c.numInputs/2, 2, c.k)
if err != nil {
return
}
c.E[0].SetData(kTable.getRows(0, c.numInputs/2)) c.E[0].SetData(kTable.getRows(0, c.numInputs/2))
c.E[1], _ = NewTable(c.numInputs/2, 2, c.k) c.E[1], err = NewTable(c.numInputs/2, 2, c.k)
if err != nil {
return
}
c.E[1].SetData(kTable.getRows(c.numInputs/2, c.numInputs/2)) c.E[1].SetData(kTable.getRows(c.numInputs/2, c.numInputs/2))
// Create d // Create d
c.D, _ = NewTable(1, 2, c.k) c.D, err = NewTable(1, 2, c.k)
if err != nil {
return
}
c.D.SetData(kTable.getRow(c.numWires - 1)) c.D.SetData(kTable.getRow(c.numWires - 1))
return return
...@@ -99,7 +107,7 @@ func Decode(d *Table, Z []byte) (bool, error) { ...@@ -99,7 +107,7 @@ func Decode(d *Table, Z []byte) (bool, error) {
//Creates a row (C_0^i, C_1^i, C_2^i ,C_3^i) for the garbled table, where //Creates a row (C_0^i, C_1^i, C_2^i ,C_3^i) for the garbled table, where
func (c *Circuit) garbleGate(K *Table, Li, Ri, out int, gateFun func(a, b int) int) []byte { func (c *Circuit) garbleGate(K *Table, Li, Ri, out int, gateFun func(a, b int) int) []byte {
cRow := make([]byte, 4*2*K.valueLen) // because 4 values of double length cRow := make([]byte, 4*2*K.valueLen) // because 4 values of double length
perm := cryptoUtil.Perm(4) // Random permutation perm := util.Perm(4) // Random permutation
// for (a,b) in {0,1} x {0,1} // for (a,b) in {0,1} x {0,1}
for a := 0; a <= 1; a++ { for a := 0; a <= 1; a++ {
...@@ -110,7 +118,6 @@ func (c *Circuit) garbleGate(K *Table, Li, Ri, out int, gateFun func(a, b int) i ...@@ -110,7 +118,6 @@ func (c *Circuit) garbleGate(K *Table, Li, Ri, out int, gateFun func(a, b int) i
copy(right[:K.valueLen], K.getValue(out, gateFun(a, b))) copy(right[:K.valueLen], K.getValue(out, gateFun(a, b)))
dst := util.XOR(left, right) // XOR with left as destination. dst := util.XOR(left, right) // XOR with left as destination.
// fmt.Printf("%s %s %s\n", hex.EncodeToString(left), hex.EncodeToString(right), hex.EncodeToString(dst))
// The index to write this value to. This depends in the row permutation. // The index to write this value to. This depends in the row permutation.
rowI := perm[a*2+b] * 2 * K.valueLen rowI := perm[a*2+b] * 2 * K.valueLen
...@@ -149,11 +156,9 @@ func (c *Circuit) Evaluate(x []byte) ([]byte, error) { ...@@ -149,11 +156,9 @@ func (c *Circuit) Evaluate(x []byte) ([]byte, error) {
for j := 0; j < 4; j++ { for j := 0; j < 4; j++ {
left := c.G(K.getValue(0, Li[i]), K.getValue(0, Ri[i]), Oi[i]) left := c.G(K.getValue(0, Li[i]), K.getValue(0, Ri[i]), Oi[i])
right := c.F.getValue(i, j) right := c.F.getValue(i, j)
dst := util.XOR(left, right) dst := util.XOR(left, right)
if bytes.Equal(zeroes, dst[K.valueLen:]) { if bytes.Equal(zeroes, dst[K.valueLen:]) {
// fmt.Printf("%s %s %s\n", hex.EncodeToString(left), hex.EncodeToString(right), hex.EncodeToString(dst))
K.setValue(0, i+c.numInputs, dst[:K.valueLen]) K.setValue(0, i+c.numInputs, dst[:K.valueLen])
success = true success = true
break break
...@@ -163,63 +168,10 @@ func (c *Circuit) Evaluate(x []byte) ([]byte, error) { ...@@ -163,63 +168,10 @@ func (c *Circuit) Evaluate(x []byte) ([]byte, error) {
return nil, fmt.Errorf("failed to evaluate circuit layer %d", i) return nil, fmt.Errorf("failed to evaluate circuit layer %d", i)
} }
} }
return K.getValue(0, c.numWires-1), nil return K.getValue(0, c.numWires-1), nil
} }
// func Decode(d [][]byte, Z []byte) bool {
// if bytes.Equal(d[0], Z) {
// return false
// } else if bytes.Equal(d[1], Z) {
// return true
// }
// //TODO: Error handling
// return false
// }
////////////////////////////////////////////////////////////////
// Circuit
////////////////////////////////////////////////////////////////
// type Circuit struct {
// wires []int
// gates []Gate
// outputWire int
// numInput int
// }
// func NewCircuit(numInput int) *Circuit {
// return &Circuit{
// wires: make([]int, numInput),
// gates: make([]Gate, 0),
// numInput: numInput,
// }
// }
// func (c *Circuit) getInputWires(i int) (Li, Ri int) {
// return i, c.numInput + i
// }
// func (c *Circuit) AddGate(Li, Ri int, f func(a, b int) int) *Gate {
// c.wires = append(c.wires, 0)
// g := newGate(Li, Ri, len(c.wires)-1, f)
// c.gates = append(c.gates, *g)
// return g
// }
// type Gate struct {
// Li, Ri int
// i int
// truthfunc func(a, b int) int
// }
// func newGate(Li, Ri, i int, f func(a, b int) int) (g *Gate) {
// g.Li = Li
// g.Ri = Ri
// g.i = i
// g.truthfunc = f
// return
// }
//////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////
// GATES // GATES
//////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////
......
...@@ -16,7 +16,7 @@ func TestGBC(t *testing.T) { ...@@ -16,7 +16,7 @@ func TestGBC(t *testing.T) {
res := c.garbleGate(K, 0, 1, 2, ORGate) res := c.garbleGate(K, 0, 1, 2, ORGate)
count := 0 count := 0
for i := 0; i < 4; i++ { for i := 0; i < 4; i++ {
index := i * K.valueLen index := i * 2 * K.valueLen
if bytes.Equal(res[index:index+K.valueLen], K.getValue(2, 1)) { if bytes.Equal(res[index:index+K.valueLen], K.getValue(2, 1)) {
count++ count++
} }
...@@ -28,7 +28,7 @@ func TestGBC(t *testing.T) { ...@@ -28,7 +28,7 @@ func TestGBC(t *testing.T) {
res = c.garbleGate(K, 0, 1, 2, ANDGate) res = c.garbleGate(K, 0, 1, 2, ANDGate)
count = 0 count = 0
for i := 0; i < 4; i++ { for i := 0; i < 4; i++ {
index := i * K.valueLen index := i * 2 * K.valueLen
if bytes.Equal(res[index:index+K.valueLen], K.getValue(2, 1)) { if bytes.Equal(res[index:index+K.valueLen], K.getValue(2, 1)) {
count++ count++
} }
......
package util package util
import ( import (
"crypto/rand" cRand "crypto/rand"
"encoding/binary" "encoding/binary"
"io" "io"
"math"
"math/big" "math/big"
mathRand "math/rand" mRand "math/rand"
) )
// Int2Bools return x as a slice containing the binary representation.
func Int2Bools(x, n int) []bool {
output := make([]bool, 0)
tmp := 0
for i := n - 1; i >= 0; i-- {
tmp = int(math.Exp2(float64(i)))
output = append(output, (x&tmp)/tmp == 1)
}
// Reverse: Needed in protocol
for i := 0; i < n/2; i++ {
temp := output[i]
output[i] = output[n-i-1]
output[n-i-1] = temp
}
return output
}
// XOR computes the byte-wise XOR of two byte slices. Slices must hve equal length.
func XOR(a, b []byte) (dst []byte) {
if len(a) != len(b) {
panic("length must be equal")
}
dst = make([]byte, len(a))
for i := range a {
dst[i] = a[i] ^ b[i]
}
return
}
var one = big.NewInt(1) var one = big.NewInt(1)
// RandInt returns a random integer in the range [1, n). // RandInt returns a random integer in the range [1, n).
...@@ -15,7 +47,7 @@ func RandInt(random io.Reader, n *big.Int) (r *big.Int, err error) { ...@@ -15,7 +47,7 @@ func RandInt(random io.Reader, n *big.Int) (r *big.Int, err error) {
tmp := new(big.Int).Set(n) tmp := new(big.Int).Set(n)
tmp.Sub(tmp, one) tmp.Sub(tmp, one)
r, err = rand.Int(random, tmp) r, err = cRand.Int(random, tmp)
if err != nil { if err != nil {
return return
} }
...@@ -24,14 +56,14 @@ func RandInt(random io.Reader, n *big.Int) (r *big.Int, err error) { ...@@ -24,14 +56,14 @@ func RandInt(random io.Reader, n *big.Int) (r *big.Int, err error) {
return return
} }
type cryptoSource [8]byte
func Perm(n int) []int { func Perm(n int) []int {
return mathRand.New(&cryptoSource{}).Perm(n) return mRand.New(&cryptoSource{}).Perm(n)
} }
type cryptoSource [8]byte
func (s *cryptoSource) Int63() int64 { func (s *cryptoSource) Int63() int64 {
_, err := rand.Read(s[:]) _, err := cRand.Read(s[:])
if err != nil { if err != nil {
panic(err) panic(err)
} }
......
package util
import (
"math"
)
// Int2Bools return x as a slice containing the binary representation.
func Int2Bools(x, n int) []bool {
output := make([]bool, 0)
tmp := 0
for i := n - 1; i >= 0; i-- {
tmp = int(math.Exp2(float64(i)))
output = append(output, (x&tmp)/tmp == 1)
}
// Reverse: Needed in protocol
for i := 0; i < n/2; i++ {
temp := output[i]
output[i] = output[n-i-1]
output[n-i-1] = temp
}
return output
}
// XOR computes the byte-wise XOR of two byte slices. Slices must hve equal length.
func XOR(a, b []byte) (dst []byte) {
if len(a) != len(b) {
panic("length must be equal")
}
dst = make([]byte, len(a))
for i := range a {
dst[i] = a[i] ^ b[i]
}
return
}