Commit 1d8c8818 authored by Anders Jensen Løvig's avatar Anders Jensen Løvig
Browse files

TestGBC works..... finally

parent 9dd8f7ad
...@@ -34,7 +34,11 @@ func NewParty(x int, inChan chan [][]byte, circuitChan chan garbled.Table, outCh ...@@ -34,7 +34,11 @@ func NewParty(x int, inChan chan [][]byte, circuitChan chan garbled.Table, outCh
} }
func (p *Party) GenerateCircuit() { func (p *Party) GenerateCircuit() {
p.Circuit = garbled.NewGarbBloodCircuit() var err error
p.Circuit, err = garbled.NewGarbBloodCircuit()
if err != nil {
panic(err)
}
} }
func (p *Party) OTSend() { func (p *Party) OTSend() {
...@@ -73,8 +77,8 @@ func RunProtocol(x, y int, outputMode int) (bool, error) { ...@@ -73,8 +77,8 @@ func RunProtocol(x, y int, outputMode int) (bool, error) {
inCh, circCh, outCh := make(chan [][]byte), make(chan garbled.Table), make(chan []byte) inCh, circCh, outCh := make(chan [][]byte), make(chan garbled.Table), make(chan []byte)
p1, p2 := oblivious.NewPartyPair() p1, p2 := oblivious.NewPartyPair()
Alice, Bob := NewParty(x, inCh, circCh, outCh, p1), NewParty(y, inCh, circCh, outCh, p2) Alice, Bob := NewParty(x, inCh, circCh, outCh, p1), NewParty(y, inCh, circCh, outCh, p2)
fmt.Printf("Running protocol with [%d,%d]\n", x,y) fmt.Printf("Running protocol with [%d,%d]\n", x, y)
go func(Bob *Party){ go func(Bob *Party) {
// 1. Garble // 1. Garble
Bob.GenerateCircuit() Bob.GenerateCircuit()
Bob.CircuitChan <- *Bob.Circuit.F Bob.CircuitChan <- *Bob.Circuit.F
...@@ -96,43 +100,43 @@ func RunProtocol(x, y int, outputMode int) (bool, error) { ...@@ -96,43 +100,43 @@ func RunProtocol(x, y int, outputMode int) (bool, error) {
} }
if outputMode != 0 { // Bob learns if outputMode != 0 { // Bob learns
// Alice sends Z // Alice sends Z
Bob.result = garbled.Decode(Bob.Circuit.D, <- Bob.outChan) Bob.result = garbled.Decode(Bob.Circuit.D, <-Bob.outChan)
fmt.Printf("Bob has received Z so he can learn\n") fmt.Printf("Bob has received Z so he can learn\n")
} }
}(Bob) }(Bob)
//go func (Alice *Party) { //go func (Alice *Party) {
// 1. Garble // 1. Garble
F := <-Alice.CircuitChan F := <-Alice.CircuitChan
Alice.Circuit = &garbled.GCircuit{F: &F} Alice.Circuit = &garbled.GCircuit{F: &F}
fmt.Printf("Alice has received F\n") fmt.Printf("Alice has received F\n")
// 2. Encode Bobs input // 2. Encode Bobs input
Alice.EncInput = <-Alice.InChan Alice.EncInput = <-Alice.InChan
// 3. Encode Alices input // 3. Encode Alices input
X, _ := Alice.OTReceive() X, _ := Alice.OTReceive()
Alice.EncInput = append(Alice.EncInput, X...) Alice.EncInput = append(Alice.EncInput, X...)
fmt.Printf("Alice is finished garbling her input\n") fmt.Printf("Alice is finished garbling her input\n")
// 4. Evaluation // 4. Evaluation
Z, eval_err := garbled.Eval(Alice.Circuit.F,Alice.EncInput) Z, eval_err := garbled.Eval(Alice.Circuit.F, Alice.EncInput)
if eval_err != nil { if eval_err != nil {
return false, eval_err return false, eval_err
} }
fmt.Printf("Alice has evaluated the circuit\n") fmt.Printf("Alice has evaluated the circuit\n")
// 5. Output // 5. Output
if outputMode != 1 { // Alice learns if outputMode != 1 { // Alice learns
// Bob sends d // Bob sends d
Alice.result = garbled.Decode(<-Alice.InChan, Z) Alice.result = garbled.Decode(<-Alice.InChan, Z)
fmt.Printf("Alice has learned the output\n") fmt.Printf("Alice has learned the output\n")
} }
if outputMode != 0 { // Bob learns if outputMode != 0 { // Bob learns
// Alice sends Z // Alice sends Z
Alice.outChan <- Z Alice.outChan <- Z
fmt.Printf("Alice has sent Z to bob\n") fmt.Printf("Alice has sent Z to bob\n")
} }
//}(Alice) //}(Alice)
if outputMode != 1 { if outputMode != 1 {
return Alice.result, nil return Alice.result, nil
} else { } else {
return Bob.result, nil return Bob.result, nil
......
...@@ -2,57 +2,105 @@ package garbled ...@@ -2,57 +2,105 @@ package garbled
import ( import (
"bytes" "bytes"
cryptoUtil "crycomp/internal/crypto/util"
"crycomp/internal/util" "crycomp/internal/util"
"crypto"
"crypto/rand" "crypto/rand"
_ "crypto/sha256" "crypto/sha256"
"errors"
"fmt" "fmt"
m "math/rand"
) )
// Table contains the random or encoded strings of bit-length k.
type Table struct { type Table struct {
// data contains the raw byte data this table holds.
data []byte data []byte
wires, numValues, numBytes int // rows is the number of rows in data (not length).
// Normally this is the numbe of wires the table was created for.
rows int
// cols is the number of strings in each row. E.g the two input
// keys (K_0^i, K_1^i)
cols int
// valueLen contains the string length (of each value). This is always k/8.
valueLen int
} }
func NewTable(wires, numValues, k int) *Table { // NewTable returns a new garbled table with n rows and each row containing numValues columns.
t := Table{ // Each value contains k/8 bytes. k must be a multiple of 8.
data : make([]byte, wires*numValues*k/8), func NewTable(n, numValues, k int) (t *Table, err error) {
wires : wires, if k%8 != 0 {
numValues : numValues, return nil, errors.New("k must be multiple of 8")
numBytes : k/8,
} }
return &t t = &Table{
data: make([]byte, n*numValues*k/8),
rows: n,
cols: numValues,
valueLen: k / 8,
}
return
} }
func (t *Table) SetValue(data []byte, i, s int) { // index returns the data index for row i and column j.
valStart := i*t.numBytes+s*t.numBytes func (t *Table) index(r, c int) int {
copy(t.data[valStart : valStart+t.numBytes], data) return r*t.cols*t.valueLen + c*t.valueLen
} }
func (t *Table) GetValue(i,s int) []byte { // // SetValue sets the byte string value at row i and column j.
valStart := i*t.numBytes+s*t.numBytes // func (t *Table) SetValue(r, c int, data []byte) error {
return t.data[valStart : valStart+t.numBytes] // if len(data) != t.valueLen {
// return errors.New("data have invalid length")
// }
// index := t.index(r, c)
// copy(t.data[index:index+t.valueLen], data)
// return nil
// }
// GetValue returns a byte slice containing the value at row i and column j.
func (t *Table) GetValue(r, c int) []byte {
index := t.index(r, c)
return t.data[index : index+t.valueLen]
} }
func (t *Table) SetRow(data [][]byte, i int) { func (t *Table) SetRow(r int, data []byte) {
for j, val := range data { if len(data) != t.cols*t.valueLen {
t.SetValue(val,i,j) panic("data have invalid length")
} }
index := t.index(r, 0)
copy(t.data[index:index+t.cols*t.valueLen], data)
} }
func (t *Table) GetRow(i int) [][]byte { func (t *Table) GetRow(r int) []byte {
values := make([][]byte, 0) index := t.index(r, 0)
for j := 0; j < t.numValues; j++ { return t.data[index : index+t.cols*t.valueLen]
values = append(values,t.GetValue(i,j))
}
return values
} }
// func (t *Table) SetValueOld(data []byte, i, s int) {
// valStart := i*t.valueLen + s*t.valueLen
// copy(t.data[valStart:valStart+t.valueLen], data)
// }
// func (t *Table) GetValueOld(i, s int) []byte {
// valStart := i*t.valueLen + s*t.valueLen
// return t.data[valStart : valStart+t.valueLen]
// }
// func (t *Table) SetRowOld(data [][]byte, i int) {
// for j, val := range data {
// t.SetValueOld(val, i, j)
// }
// }
// func (t *Table) GetRowOld(i int) [][]byte {
// values := make([][]byte, 0)
// for j := 0; j < t.cols; j++ {
// values = append(values, t.GetValueOld(i, j))
// }
// return values
// }
func (t *Table) getKFirstValues(numInputWire int) [][][]byte { func (t *Table) getKFirstValues(numInputWire int) [][][]byte {
values := make([][][]byte, 0) values := make([][][]byte, 0)
for i := 0; i < numInputWire; i++ { for i := 0; i < numInputWire; i++ {
values = append(values,t.GetRow(i)) // values = append(values, t.GetRowOld(i))
} }
return values return values
} }
...@@ -67,61 +115,80 @@ type GCircuit struct { ...@@ -67,61 +115,80 @@ type GCircuit struct {
D [][]byte D [][]byte
} }
type Protocol struct {
G func(a, b []byte, i int) []byte
}
func NewGarbBloodCircuit() (c *GCircuit, err error) {
p := Protocol{G}
func NewGarbBloodCircuit() (circuit *GCircuit) {
numInput := 6 numInput := 6
numWires := 11 numWires := 11 // Includes numInput
k_bits := 128 k_bits := 128
// (1)
K := NewTable(numWires,2,k_bits) // 1. Create two random strings for each wire.
K, err := NewTable(numWires, 2, k_bits)
if err != nil {
return
}
K.randomizeTable() K.randomizeTable()
// (2) // 2. Create a garbled table for all C values (4 for each gate)
C := NewTable(numWires-numInput,4,k_bits*2) C, err := NewTable(numWires-numInput, 4, k_bits*2)
for i := 0; i < 3; i++ { //1st layer (consisting of not's and or's ) if err != nil {
//INput wires order: xs, xa, xb, ys, ya, yb (hopefully) return
C.SetRow(GBC(K, i,i+3, i+numInput, ORGateWithNot, G),i)
} }
C.SetRow(GBC(K,numInput,numInput+1,numInput+3, ANDGate,G),4)
C.SetRow(GBC(K,numInput+3,numInput+2,numInput+4, ANDGate, G),5)
return &GCircuit{ // // 1st layer are OR-gates with b input being NOT
F : C, C.SetRow(0, p.garbleGate(K, 0, 3, 6, ORGateWithNot))
D : K.GetRow(numWires), C.SetRow(1, p.garbleGate(K, 1, 4, 7, ORGateWithNot))
E : K.getKFirstValues(numInput), C.SetRow(2, p.garbleGate(K, 2, 5, 8, ORGateWithNot))
C.SetRow(3, p.garbleGate(K, 6, 7, 9, ANDGate))
C.SetRow(4, p.garbleGate(K, 8, 9, 10, ANDGate))
c = &GCircuit{
F: C,
E: K.getKFirstValues(numInput),
// D: K.GetRowOld(numWires),
} }
return
} }
//Creates a row (C_0^i, C_1^i, C_2^i ,C_3^i) for the garbled table, where
func GBC (K *Table, Li,Ri,i int, c func(a,b int) int, g func(A,B []byte, i int) []byte) (CRow [][]byte) { //Creates a row (C_0^i, C_1^i, C_2^i ,C_3^i) for the garbled table, where
CRow = make([][]byte, 0) func (p *Protocol) garbleGate(K *Table, Li, Ri, out int, c func(a, b int) int) (cRow []byte) {
CVals := make([][]byte, 4) // 4 x 256 bits cRow = make([]byte, K.valueLen*4)
// Random permutation
// (a) perm := cryptoUtil.Perm(4)
for a := range []int{0,1} {
for b := range []int{0,1} { // for (a,b) in {0,1} x {0,1}
left := g(K.GetValue(Li,a),K.GetValue(Ri,b),i) for a := 0; a <= 1; a++ {
right := make([]byte,16) for b := 0; b <= 1; b++ {
copy(right, K.GetValue(i,c(a,b))) left := p.G(K.GetValue(Li, a), K.GetValue(Ri, b), out)
right = append(right, make([]byte,16)...) right := make([]byte, K.valueLen*2)
CVals[a*2+b] = util.XOR(left, right) copy(right[:K.valueLen], K.GetValue(out, c(a, b)))
// right = append(right, make([]byte, 16)...)
// The index to write this value to. This depends in the row permutation.
rowI := perm[a*2+b] * K.valueLen
util.XOR(left, left, right) // XOR with left as destination.
copy(cRow[rowI:rowI+K.valueLen], left)
// cRow[rowI : rowI+K.valueLen] =
// cVals[a*2+b] = util.XOR(left, right)
} }
} }
// (b)
//TODO: Fix permutation
for i:= range m.Perm(4) { //Generate random order
CRow = append(CRow, CVals[i])
}
return return
} }
func G (A,B []byte, i int) []byte { func G(A, B []byte, i int) []byte {
hash := crypto.SHA256.New() hash := sha256.New()
hash.Write(A) hash.Write(A)
hash.Write(append(B,byte(i))) hash.Write(append(B, byte(i)))
return hash.Sum(nil) return hash.Sum(nil)
} }
func Enc(e [][][]byte, x []bool) (X [][]byte){ func Enc(e [][][]byte, x []bool) (X [][]byte) {
X = make([][]byte, 0) X = make([][]byte, 0)
for i, val := range x { for i, val := range x {
if val { if val {
...@@ -133,13 +200,13 @@ func Enc(e [][][]byte, x []bool) (X [][]byte){ ...@@ -133,13 +200,13 @@ func Enc(e [][][]byte, x []bool) (X [][]byte){
return return
} }
func Eval(C *Table, X [][]byte) ([]byte, error){ func Eval(C *Table, X [][]byte) ([]byte, error) {
var K = X var K = X
var res []byte var res []byte
for i := 0; i < 3; i++ { for i := 0; i < 3; i++ {
for j:= 0; j < 4; j++ { for j := 0; j < 4; j++ {
res = util.XOR(G(K[i],K[i+3],i), C.GetValue(i,j)) // res = util.XOR(G(K[i], K[i+3], i), C.GetValueOld(i, j))
if bytes.Equal(res[16:], make([]byte,16)) { if bytes.Equal(res[16:], make([]byte, 16)) {
K = append(K, res[:16]) K = append(K, res[:16])
break break
} }
...@@ -150,9 +217,9 @@ func Eval(C *Table, X [][]byte) ([]byte, error){ ...@@ -150,9 +217,9 @@ func Eval(C *Table, X [][]byte) ([]byte, error){
} }
func Decode(d [][]byte, Z []byte) bool { func Decode(d [][]byte, Z []byte) bool {
if bytes.Equal(d[0],Z) { if bytes.Equal(d[0], Z) {
return false return false
} else if bytes.Equal(d[1],Z) { } else if bytes.Equal(d[1], Z) {
return true return true
} }
//TODO: Error handling //TODO: Error handling
...@@ -164,37 +231,38 @@ func Decode(d [][]byte, Z []byte) bool { ...@@ -164,37 +231,38 @@ func Decode(d [][]byte, Z []byte) bool {
//////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////
type Circuit struct { type Circuit struct {
wires []int wires []int
gates []Gate gates []Gate
outputWire int outputWire int
numInput int numInput int
} }
func NewCircuit(numInput int) *Circuit { func NewCircuit(numInput int) *Circuit {
return &Circuit{ return &Circuit{
wires: make([]int, numInput), wires: make([]int, numInput),
gates: make([]Gate, 0), gates: make([]Gate, 0),
numInput : numInput, numInput: numInput,
} }
} }
func (c *Circuit) getInputWires(i int) (Li, Ri int) { func (c *Circuit) getInputWires(i int) (Li, Ri int) {
return i, c.numInput+i return i, c.numInput + i
} }
func (c *Circuit) AddGate(Li, Ri int, f func(a,b int) int) *Gate { func (c *Circuit) AddGate(Li, Ri int, f func(a, b int) int) *Gate {
c.wires = append(c.wires, 0) c.wires = append(c.wires, 0)
g := newGate(Li,Ri,len(c.wires)-1,f) g := newGate(Li, Ri, len(c.wires)-1, f)
c.gates = append(c.gates, *g) c.gates = append(c.gates, *g)
return g return g
} }
type Gate struct { type Gate struct {
Li, Ri int Li, Ri int
i int i int
truthfunc func(a,b int) int truthfunc func(a, b int) int
} }
func newGate(Li, Ri, i int, f func(a,b int) int) (g *Gate) { func newGate(Li, Ri, i int, f func(a, b int) int) (g *Gate) {
g.Li = Li g.Li = Li
g.Ri = Ri g.Ri = Ri
g.i = i g.i = i
...@@ -214,14 +282,14 @@ func ORGate(a, b int) int { ...@@ -214,14 +282,14 @@ func ORGate(a, b int) int {
} }
} }
func ORGateWithNot(a,b int) int { func ORGateWithNot(a, b int) int {
if ((1-a)+b) == 0 { if ((1 - b) + a) == 0 {
return 0 return 0
} else { } else {
return 1 return 1
} }
} }
func ANDGate(a,b int) int { func ANDGate(a, b int) int {
return a*b return a * b
} }
package garbled package garbled
import ( import (
"testing"
"bytes" "bytes"
"testing"
) )
func TestNewTable(t *testing.T) {
table, err := NewTable(1, 1, 8)
if err != nil {
t.Fatal(err)
}
if len(table.data) != 1 {
t.Errorf("Expected len(table.data) == 1, got %d", len(table.data))
}
table, err = NewTable(2, 1, 16)
if err != nil {
t.Fatal(err)
}
if len(table.data) != 4 {
t.Errorf("Expected len(table.data) == 4, got %d", len(table.data))
}
table, err = NewTable(2, 2, 16)
if err != nil {
t.Fatal(err)
}
if len(table.data) != 8 {
t.Errorf("Expected len(table.data) == 8, got %d", len(table.data))
}
}
func TestIndex(t *testing.T) {
table, _ := NewTable(2, 2, 16)
if i := table.index(0, 0); i != 0 {
t.Errorf("Expected table index(0, 0) == 0, got %d", i)
}
if i := table.index(0, 1); i != 2 {
t.Errorf("Expected table index(0, 1) == 2, got %d", i)
}
if i := table.index(1, 0); i != 4 {
t.Errorf("Expected table index(1, 0) == 4, got %d", i)
}
if i := table.index(1, 1); i != 6 {
t.Errorf("Expected table index(1, 1) == 8, got %d", i)
}
}
func AssertByteSlice(t *testing.T, actual, expected []byte) {
if !bytes.Equal(actual, expected) {
t.Errorf("Expected bytes to be equal")
}
}
func TestTableValue(t *testing.T) {
table, _ := NewTable(2, 2, 16)
table.SetRow(0, []byte{123, 81, 0, 0})
table.SetRow(1, []byte{1, 2, 3, 4})