Commit 40c9c061 authored by Anders Jensen Løvig's avatar Anders Jensen Løvig
Browse files

Almost done

parent 51765d10
...@@ -67,7 +67,7 @@ func NewParty(in chan []byte, out chan []byte, obliv *oblivious.Party) (p *Party ...@@ -67,7 +67,7 @@ func NewParty(in chan []byte, out chan []byte, obliv *oblivious.Party) (p *Party
return return
} }
func (alice *Party) RunAlice(x int) (err error) { func (alice *Party) RunAlice(x int) (result bool, err error) {
// 1. Garble // 1. Garble
// F, _ := garbled.NewTable(NumWires-NumInputs*2, 4, K*2) // F, _ := garbled.NewTable(NumWires-NumInputs*2, 4, K*2)
alice.circuit.F.SetData(<-alice.receive) alice.circuit.F.SetData(<-alice.receive)
...@@ -86,11 +86,17 @@ func (alice *Party) RunAlice(x int) (err error) { ...@@ -86,11 +86,17 @@ func (alice *Party) RunAlice(x int) (err error) {
// 4. Evaluation // 4. Evaluation
evalInput := append(X, Y...) evalInput := append(X, Y...)
// fmt.Println(len(evalInput)) Z, err := alice.circuit.Evaluate(evalInput)
_, err = alice.circuit.Evaluate(evalInput)
if err != nil { if err != nil {
return return
} }
fmt.Println("Alice: Circuit evaluated")
// 5b. Alice Output
alice.circuit.D, _ = garbled.NewTable(1, 2, K)
alice.circuit.D.SetData(<-alice.receive)
result, err = garbled.Decode(alice.circuit.D, Z)
fmt.Println("Alice: Decoded result")
return return
} }
...@@ -122,6 +128,11 @@ func (bob *Party) RunBob(y int) (err error) { ...@@ -122,6 +128,11 @@ func (bob *Party) RunBob(y int) (err error) {
} }
// 4. Evaluation // 4. Evaluation
// do nothing
// 5b. Alice output
fmt.Println("Bob: Sending d to Alice")
bob.send <- bob.circuit.D.GetData()
return return
} }
...@@ -173,8 +184,7 @@ func RunProtocol(x, y int, outputMode int) (z bool, err error) { ...@@ -173,8 +184,7 @@ func RunProtocol(x, y int, outputMode int) (z bool, err error) {
} }
}() }()
err = p.A.RunAlice(x) z, err = p.A.RunAlice(x)
return return
// go func(bob *Party) { // go func(bob *Party) {
......
...@@ -5,7 +5,7 @@ import ( ...@@ -5,7 +5,7 @@ import (
cryptoUtil "crycomp/internal/crypto/util" cryptoUtil "crycomp/internal/crypto/util"
"crycomp/internal/util" "crycomp/internal/util"
"crypto/sha256" "crypto/sha256"
"encoding/hex" "errors"
"fmt" "fmt"
) )
...@@ -86,10 +86,20 @@ func Encode(e *Table, x []bool) (X []byte) { ...@@ -86,10 +86,20 @@ func Encode(e *Table, x []bool) (X []byte) {
return return
} }
func Decode(d *Table, Z []byte) (bool, error) {
if bytes.Equal(d.getValue(0, 0), Z) {
return false, nil
} else if bytes.Equal(d.getValue(0, 1), Z) {
return true, nil
} else {
return false, errors.New("failed to decode")
}
}
//Creates a row (C_0^i, C_1^i, C_2^i ,C_3^i) for the garbled table, where //Creates a row (C_0^i, C_1^i, C_2^i ,C_3^i) for the garbled table, where
func (c *Circuit) garbleGate(K *Table, Li, Ri, out int, gateFun func(a, b int) int) []byte { func (c *Circuit) garbleGate(K *Table, Li, Ri, out int, gateFun func(a, b int) int) []byte {
cRow := make([]byte, K.valueLen*8) // because 4 values of double length cRow := make([]byte, 4*2*K.valueLen) // because 4 values of double length
perm := cryptoUtil.Perm(4) // Random permutation perm := cryptoUtil.Perm(4) // Random permutation
// for (a,b) in {0,1} x {0,1} // for (a,b) in {0,1} x {0,1}
for a := 0; a <= 1; a++ { for a := 0; a <= 1; a++ {
...@@ -100,14 +110,14 @@ func (c *Circuit) garbleGate(K *Table, Li, Ri, out int, gateFun func(a, b int) i ...@@ -100,14 +110,14 @@ func (c *Circuit) garbleGate(K *Table, Li, Ri, out int, gateFun func(a, b int) i
copy(right[:K.valueLen], K.getValue(out, gateFun(a, b))) copy(right[:K.valueLen], K.getValue(out, gateFun(a, b)))
dst := util.XOR(left, right) // XOR with left as destination. dst := util.XOR(left, right) // XOR with left as destination.
fmt.Printf("%s %s %s\n", hex.EncodeToString(left), hex.EncodeToString(right), hex.EncodeToString(dst)) // fmt.Printf("%s %s %s\n", hex.EncodeToString(left), hex.EncodeToString(right), hex.EncodeToString(dst))
// The index to write this value to. This depends in the row permutation. // The index to write this value to. This depends in the row permutation.
rowI := perm[a*2+b] * K.valueLen rowI := perm[a*2+b] * 2 * K.valueLen
copy(cRow[rowI:rowI+2*K.valueLen], dst)
copy(cRow[rowI:rowI+K.valueLen], dst)
} }
} }
return cRow return cRow
} }
...@@ -130,52 +140,41 @@ func (c *Circuit) Evaluate(x []byte) ([]byte, error) { ...@@ -130,52 +140,41 @@ func (c *Circuit) Evaluate(x []byte) ([]byte, error) {
Ri := []int{3, 4, 5, 7, 9} Ri := []int{3, 4, 5, 7, 9}
Oi := []int{6, 7, 8, 9, 10} Oi := []int{6, 7, 8, 9, 10}
// zeroes := make([]byte, K.valueLen) zeroes := make([]byte, K.valueLen)
fmt.Println(c.F.valueLen)
// For each circuit gate // For each circuit gate
for i := 0; i < c.F.rows; i++ { for i := 0; i < c.F.rows; i++ {
// For each C // For each C
success := false
for j := 0; j < 4; j++ { for j := 0; j < 4; j++ {
left := c.G(K.getValue(0, Li[i]), K.getValue(0, Ri[i]), Oi[i]) left := c.G(K.getValue(0, Li[i]), K.getValue(0, Ri[i]), Oi[i])
right := c.F.getValue(i, j) right := c.F.getValue(i, j)
dst := util.XOR(left, right) dst := util.XOR(left, right)
fmt.Printf("%s %s %s\n", hex.EncodeToString(left), hex.EncodeToString(right), hex.EncodeToString(dst))
// fmt.Printf(len(left)) if bytes.Equal(zeroes, dst[K.valueLen:]) {
// fmt.Printf("%s %s %s\n", hex.EncodeToString(left), hex.EncodeToString(right), hex.EncodeToString(dst))
// dst := make([]byte, 2*F.valueLen) K.setValue(0, i+c.numInputs, dst[:K.valueLen])
// util.XOR(dst, K.getValue()) success = true
break
}
}
if !success {
return nil, fmt.Errorf("failed to evaluate circuit layer %d", i)
} }
} }
return K.getValue(0, c.numWires-1), nil
// var K = x
// var res []byte
// for i := 0; i < 3; i++ {
// for j := 0; j < 4; j++ {
// // res = util.XOR(G(K[i], K[i+3], i), C.GetValueOld(i, j))
// if bytes.Equal(res[16:], make([]byte, 16)) {
// K = append(K, res[:16])
// break
// }
// }
// return nil, fmt.Errorf("Aborted while evaluating the garbled circuit, no match was found")
// }
// return K[len(K)-1], nil
return nil, nil
} }
func Decode(d [][]byte, Z []byte) bool { // func Decode(d [][]byte, Z []byte) bool {
if bytes.Equal(d[0], Z) { // if bytes.Equal(d[0], Z) {
return false // return false
} else if bytes.Equal(d[1], Z) { // } else if bytes.Equal(d[1], Z) {
return true // return true
} // }
//TODO: Error handling // //TODO: Error handling
return false // return false
} // }
//////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////
// Circuit // Circuit
......
...@@ -39,6 +39,14 @@ func (t *Table) index(r, c int) int { ...@@ -39,6 +39,14 @@ func (t *Table) index(r, c int) int {
return r*t.cols*t.valueLen + c*t.valueLen return r*t.cols*t.valueLen + c*t.valueLen
} }
func (t *Table) setValue(r, c int, data []byte) {
if len(data) != t.valueLen {
panic("data have invalid length")
}
index := t.index(r, c)
copy(t.data[index:index+t.valueLen], data)
}
// getValue returns a byte slice containing the value at row i and column j. // getValue returns a byte slice containing the value at row i and column j.
func (t *Table) getValue(r, c int) (data []byte) { func (t *Table) getValue(r, c int) (data []byte) {
index := t.index(r, c) index := t.index(r, c)
...@@ -62,9 +70,11 @@ func (t *Table) getRow(r int) (data []byte) { ...@@ -62,9 +70,11 @@ func (t *Table) getRow(r int) (data []byte) {
return return
} }
func (t *Table) getRows(r, count int) []byte { func (t *Table) getRows(r, count int) (data []byte) {
index := t.index(r, 0) index := t.index(r, 0)
return t.data[index : index+count*t.cols*t.valueLen] data = make([]byte, count*t.cols*t.valueLen)
copy(data, t.data[index:index+count*t.cols*t.valueLen])
return
} }
func (t *Table) SetData(data []byte) { func (t *Table) SetData(data []byte) {
......
...@@ -69,3 +69,13 @@ func TestTableValue(t *testing.T) { ...@@ -69,3 +69,13 @@ func TestTableValue(t *testing.T) {
AssertByteSlice(t, table.getRows(0, 2), []byte{123, 81, 0, 0, 1, 2, 3, 4}) AssertByteSlice(t, table.getRows(0, 2), []byte{123, 81, 0, 0, 1, 2, 3, 4})
AssertByteSlice(t, table.getRows(1, 1), []byte{1, 2, 3, 4}) AssertByteSlice(t, table.getRows(1, 1), []byte{1, 2, 3, 4})
} }
func TestGetData(t *testing.T) {
c, _ := NewCircuit(6, 11, 128)
_ = c.GarbleBloodCircuit()
data := c.F.GetData()
if !bytes.Equal(data, c.F.data) {
t.Errorf("Bytes not equal")
}
}
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment