Commit 1d8c8818 authored by Anders Jensen Løvig's avatar Anders Jensen Løvig
Browse files

TestGBC works..... finally

parent 9dd8f7ad
......@@ -34,7 +34,11 @@ func NewParty(x int, inChan chan [][]byte, circuitChan chan garbled.Table, outCh
}
func (p *Party) GenerateCircuit() {
p.Circuit = garbled.NewGarbBloodCircuit()
var err error
p.Circuit, err = garbled.NewGarbBloodCircuit()
if err != nil {
panic(err)
}
}
func (p *Party) OTSend() {
......@@ -73,8 +77,8 @@ func RunProtocol(x, y int, outputMode int) (bool, error) {
inCh, circCh, outCh := make(chan [][]byte), make(chan garbled.Table), make(chan []byte)
p1, p2 := oblivious.NewPartyPair()
Alice, Bob := NewParty(x, inCh, circCh, outCh, p1), NewParty(y, inCh, circCh, outCh, p2)
fmt.Printf("Running protocol with [%d,%d]\n", x,y)
go func(Bob *Party){
fmt.Printf("Running protocol with [%d,%d]\n", x, y)
go func(Bob *Party) {
// 1. Garble
Bob.GenerateCircuit()
Bob.CircuitChan <- *Bob.Circuit.F
......@@ -96,43 +100,43 @@ func RunProtocol(x, y int, outputMode int) (bool, error) {
}
if outputMode != 0 { // Bob learns
// Alice sends Z
Bob.result = garbled.Decode(Bob.Circuit.D, <- Bob.outChan)
Bob.result = garbled.Decode(Bob.Circuit.D, <-Bob.outChan)
fmt.Printf("Bob has received Z so he can learn\n")
}
}(Bob)
//go func (Alice *Party) {
// 1. Garble
F := <-Alice.CircuitChan
// 1. Garble
F := <-Alice.CircuitChan
Alice.Circuit = &garbled.GCircuit{F: &F}
fmt.Printf("Alice has received F\n")
// 2. Encode Bobs input
Alice.EncInput = <-Alice.InChan
// 3. Encode Alices input
X, _ := Alice.OTReceive()
Alice.EncInput = append(Alice.EncInput, X...)
fmt.Printf("Alice is finished garbling her input\n")
// 4. Evaluation
Z, eval_err := garbled.Eval(Alice.Circuit.F,Alice.EncInput)
if eval_err != nil {
return false, eval_err
}
fmt.Printf("Alice has evaluated the circuit\n")
// 5. Output
if outputMode != 1 { // Alice learns
// Bob sends d
Alice.result = garbled.Decode(<-Alice.InChan, Z)
fmt.Printf("Alice has learned the output\n")
}
if outputMode != 0 { // Bob learns
// Alice sends Z
Alice.outChan <- Z
fmt.Printf("Alice has sent Z to bob\n")
}
Alice.Circuit = &garbled.GCircuit{F: &F}
fmt.Printf("Alice has received F\n")
// 2. Encode Bobs input
Alice.EncInput = <-Alice.InChan
// 3. Encode Alices input
X, _ := Alice.OTReceive()
Alice.EncInput = append(Alice.EncInput, X...)
fmt.Printf("Alice is finished garbling her input\n")
// 4. Evaluation
Z, eval_err := garbled.Eval(Alice.Circuit.F, Alice.EncInput)
if eval_err != nil {
return false, eval_err
}
fmt.Printf("Alice has evaluated the circuit\n")
// 5. Output
if outputMode != 1 { // Alice learns
// Bob sends d
Alice.result = garbled.Decode(<-Alice.InChan, Z)
fmt.Printf("Alice has learned the output\n")
}
if outputMode != 0 { // Bob learns
// Alice sends Z
Alice.outChan <- Z
fmt.Printf("Alice has sent Z to bob\n")
}
//}(Alice)
if outputMode != 1 {
if outputMode != 1 {
return Alice.result, nil
} else {
return Bob.result, nil
......
......@@ -2,57 +2,105 @@ package garbled
import (
"bytes"
cryptoUtil "crycomp/internal/crypto/util"
"crycomp/internal/util"
"crypto"
"crypto/rand"
_ "crypto/sha256"
"crypto/sha256"
"errors"
"fmt"
m "math/rand"
)
// Table contains the random or encoded strings of bit-length k.
type Table struct {
// data contains the raw byte data this table holds.
data []byte
wires, numValues, numBytes int
// rows is the number of rows in data (not length).
// Normally this is the numbe of wires the table was created for.
rows int
// cols is the number of strings in each row. E.g the two input
// keys (K_0^i, K_1^i)
cols int
// valueLen contains the string length (of each value). This is always k/8.
valueLen int
}
func NewTable(wires, numValues, k int) *Table {
t := Table{
data : make([]byte, wires*numValues*k/8),
wires : wires,
numValues : numValues,
numBytes : k/8,
// NewTable returns a new garbled table with n rows and each row containing numValues columns.
// Each value contains k/8 bytes. k must be a multiple of 8.
func NewTable(n, numValues, k int) (t *Table, err error) {
if k%8 != 0 {
return nil, errors.New("k must be multiple of 8")
}
return &t
t = &Table{
data: make([]byte, n*numValues*k/8),
rows: n,
cols: numValues,
valueLen: k / 8,
}
return
}
func (t *Table) SetValue(data []byte, i, s int) {
valStart := i*t.numBytes+s*t.numBytes
copy(t.data[valStart : valStart+t.numBytes], data)
// index returns the data index for row i and column j.
func (t *Table) index(r, c int) int {
return r*t.cols*t.valueLen + c*t.valueLen
}
func (t *Table) GetValue(i,s int) []byte {
valStart := i*t.numBytes+s*t.numBytes
return t.data[valStart : valStart+t.numBytes]
// // SetValue sets the byte string value at row i and column j.
// func (t *Table) SetValue(r, c int, data []byte) error {
// if len(data) != t.valueLen {
// return errors.New("data have invalid length")
// }
// index := t.index(r, c)
// copy(t.data[index:index+t.valueLen], data)
// return nil
// }
// GetValue returns a byte slice containing the value at row i and column j.
func (t *Table) GetValue(r, c int) []byte {
index := t.index(r, c)
return t.data[index : index+t.valueLen]
}
func (t *Table) SetRow(data [][]byte, i int) {
for j, val := range data {
t.SetValue(val,i,j)
func (t *Table) SetRow(r int, data []byte) {
if len(data) != t.cols*t.valueLen {
panic("data have invalid length")
}
index := t.index(r, 0)
copy(t.data[index:index+t.cols*t.valueLen], data)
}
func (t *Table) GetRow(i int) [][]byte {
values := make([][]byte, 0)
for j := 0; j < t.numValues; j++ {
values = append(values,t.GetValue(i,j))
}
return values
func (t *Table) GetRow(r int) []byte {
index := t.index(r, 0)
return t.data[index : index+t.cols*t.valueLen]
}
// func (t *Table) SetValueOld(data []byte, i, s int) {
// valStart := i*t.valueLen + s*t.valueLen
// copy(t.data[valStart:valStart+t.valueLen], data)
// }
// func (t *Table) GetValueOld(i, s int) []byte {
// valStart := i*t.valueLen + s*t.valueLen
// return t.data[valStart : valStart+t.valueLen]
// }
// func (t *Table) SetRowOld(data [][]byte, i int) {
// for j, val := range data {
// t.SetValueOld(val, i, j)
// }
// }
// func (t *Table) GetRowOld(i int) [][]byte {
// values := make([][]byte, 0)
// for j := 0; j < t.cols; j++ {
// values = append(values, t.GetValueOld(i, j))
// }
// return values
// }
func (t *Table) getKFirstValues(numInputWire int) [][][]byte {
values := make([][][]byte, 0)
for i := 0; i < numInputWire; i++ {
values = append(values,t.GetRow(i))
// values = append(values, t.GetRowOld(i))
}
return values
}
......@@ -67,61 +115,80 @@ type GCircuit struct {
D [][]byte
}
type Protocol struct {
G func(a, b []byte, i int) []byte
}
func NewGarbBloodCircuit() (c *GCircuit, err error) {
p := Protocol{G}
func NewGarbBloodCircuit() (circuit *GCircuit) {
numInput := 6
numWires := 11
numWires := 11 // Includes numInput
k_bits := 128
// (1)
K := NewTable(numWires,2,k_bits)
// 1. Create two random strings for each wire.
K, err := NewTable(numWires, 2, k_bits)
if err != nil {
return
}
K.randomizeTable()
// (2)
C := NewTable(numWires-numInput,4,k_bits*2)
for i := 0; i < 3; i++ { //1st layer (consisting of not's and or's )
//INput wires order: xs, xa, xb, ys, ya, yb (hopefully)
C.SetRow(GBC(K, i,i+3, i+numInput, ORGateWithNot, G),i)
// 2. Create a garbled table for all C values (4 for each gate)
C, err := NewTable(numWires-numInput, 4, k_bits*2)
if err != nil {
return
}
C.SetRow(GBC(K,numInput,numInput+1,numInput+3, ANDGate,G),4)
C.SetRow(GBC(K,numInput+3,numInput+2,numInput+4, ANDGate, G),5)
return &GCircuit{
F : C,
D : K.GetRow(numWires),
E : K.getKFirstValues(numInput),
// // 1st layer are OR-gates with b input being NOT
C.SetRow(0, p.garbleGate(K, 0, 3, 6, ORGateWithNot))
C.SetRow(1, p.garbleGate(K, 1, 4, 7, ORGateWithNot))
C.SetRow(2, p.garbleGate(K, 2, 5, 8, ORGateWithNot))
C.SetRow(3, p.garbleGate(K, 6, 7, 9, ANDGate))
C.SetRow(4, p.garbleGate(K, 8, 9, 10, ANDGate))
c = &GCircuit{
F: C,
E: K.getKFirstValues(numInput),
// D: K.GetRowOld(numWires),
}
return
}
//Creates a row (C_0^i, C_1^i, C_2^i ,C_3^i) for the garbled table, where
func GBC (K *Table, Li,Ri,i int, c func(a,b int) int, g func(A,B []byte, i int) []byte) (CRow [][]byte) {
CRow = make([][]byte, 0)
CVals := make([][]byte, 4) // 4 x 256 bits
// (a)
for a := range []int{0,1} {
for b := range []int{0,1} {
left := g(K.GetValue(Li,a),K.GetValue(Ri,b),i)
right := make([]byte,16)
copy(right, K.GetValue(i,c(a,b)))
right = append(right, make([]byte,16)...)
CVals[a*2+b] = util.XOR(left, right)
//Creates a row (C_0^i, C_1^i, C_2^i ,C_3^i) for the garbled table, where
func (p *Protocol) garbleGate(K *Table, Li, Ri, out int, c func(a, b int) int) (cRow []byte) {
cRow = make([]byte, K.valueLen*4)
// Random permutation
perm := cryptoUtil.Perm(4)
// for (a,b) in {0,1} x {0,1}
for a := 0; a <= 1; a++ {
for b := 0; b <= 1; b++ {
left := p.G(K.GetValue(Li, a), K.GetValue(Ri, b), out)
right := make([]byte, K.valueLen*2)
copy(right[:K.valueLen], K.GetValue(out, c(a, b)))
// right = append(right, make([]byte, 16)...)
// The index to write this value to. This depends in the row permutation.
rowI := perm[a*2+b] * K.valueLen
util.XOR(left, left, right) // XOR with left as destination.
copy(cRow[rowI:rowI+K.valueLen], left)
// cRow[rowI : rowI+K.valueLen] =
// cVals[a*2+b] = util.XOR(left, right)
}
}
// (b)
//TODO: Fix permutation
for i:= range m.Perm(4) { //Generate random order
CRow = append(CRow, CVals[i])
}
return
}
func G (A,B []byte, i int) []byte {
hash := crypto.SHA256.New()
func G(A, B []byte, i int) []byte {
hash := sha256.New()
hash.Write(A)
hash.Write(append(B,byte(i)))
hash.Write(append(B, byte(i)))
return hash.Sum(nil)
}
func Enc(e [][][]byte, x []bool) (X [][]byte){
func Enc(e [][][]byte, x []bool) (X [][]byte) {
X = make([][]byte, 0)
for i, val := range x {
if val {
......@@ -133,13 +200,13 @@ func Enc(e [][][]byte, x []bool) (X [][]byte){
return
}
func Eval(C *Table, X [][]byte) ([]byte, error){
func Eval(C *Table, X [][]byte) ([]byte, error) {
var K = X
var res []byte
for i := 0; i < 3; i++ {
for j:= 0; j < 4; j++ {
res = util.XOR(G(K[i],K[i+3],i), C.GetValue(i,j))
if bytes.Equal(res[16:], make([]byte,16)) {
for j := 0; j < 4; j++ {
// res = util.XOR(G(K[i], K[i+3], i), C.GetValueOld(i, j))
if bytes.Equal(res[16:], make([]byte, 16)) {
K = append(K, res[:16])
break
}
......@@ -150,9 +217,9 @@ func Eval(C *Table, X [][]byte) ([]byte, error){
}
func Decode(d [][]byte, Z []byte) bool {
if bytes.Equal(d[0],Z) {
if bytes.Equal(d[0], Z) {
return false
} else if bytes.Equal(d[1],Z) {
} else if bytes.Equal(d[1], Z) {
return true
}
//TODO: Error handling
......@@ -164,37 +231,38 @@ func Decode(d [][]byte, Z []byte) bool {
////////////////////////////////////////////////////////////////
type Circuit struct {
wires []int
gates []Gate
wires []int
gates []Gate
outputWire int
numInput int
numInput int
}
func NewCircuit(numInput int) *Circuit {
return &Circuit{
wires: make([]int, numInput),
gates: make([]Gate, 0),
numInput : numInput,
wires: make([]int, numInput),
gates: make([]Gate, 0),
numInput: numInput,
}
}
func (c *Circuit) getInputWires(i int) (Li, Ri int) {
return i, c.numInput+i
return i, c.numInput + i
}
func (c *Circuit) AddGate(Li, Ri int, f func(a,b int) int) *Gate {
func (c *Circuit) AddGate(Li, Ri int, f func(a, b int) int) *Gate {
c.wires = append(c.wires, 0)
g := newGate(Li,Ri,len(c.wires)-1,f)
g := newGate(Li, Ri, len(c.wires)-1, f)
c.gates = append(c.gates, *g)
return g
}
type Gate struct {
Li, Ri int
i int
truthfunc func(a,b int) int
Li, Ri int
i int
truthfunc func(a, b int) int
}
func newGate(Li, Ri, i int, f func(a,b int) int) (g *Gate) {
func newGate(Li, Ri, i int, f func(a, b int) int) (g *Gate) {
g.Li = Li
g.Ri = Ri
g.i = i
......@@ -214,14 +282,14 @@ func ORGate(a, b int) int {
}
}
func ORGateWithNot(a,b int) int {
if ((1-a)+b) == 0 {
func ORGateWithNot(a, b int) int {
if ((1 - b) + a) == 0 {
return 0
} else {
return 1
}
}
func ANDGate(a,b int) int {
return a*b
func ANDGate(a, b int) int {
return a * b
}
package garbled
import (
"testing"
"bytes"
"testing"
)
func TestNewTable(t *testing.T) {
table, err := NewTable(1, 1, 8)
if err != nil {
t.Fatal(err)
}
if len(table.data) != 1 {
t.Errorf("Expected len(table.data) == 1, got %d", len(table.data))
}
table, err = NewTable(2, 1, 16)
if err != nil {
t.Fatal(err)
}
if len(table.data) != 4 {
t.Errorf("Expected len(table.data) == 4, got %d", len(table.data))
}
table, err = NewTable(2, 2, 16)
if err != nil {
t.Fatal(err)
}
if len(table.data) != 8 {
t.Errorf("Expected len(table.data) == 8, got %d", len(table.data))
}
}
func TestIndex(t *testing.T) {
table, _ := NewTable(2, 2, 16)
if i := table.index(0, 0); i != 0 {
t.Errorf("Expected table index(0, 0) == 0, got %d", i)
}
if i := table.index(0, 1); i != 2 {
t.Errorf("Expected table index(0, 1) == 2, got %d", i)
}
if i := table.index(1, 0); i != 4 {
t.Errorf("Expected table index(1, 0) == 4, got %d", i)
}
if i := table.index(1, 1); i != 6 {
t.Errorf("Expected table index(1, 1) == 8, got %d", i)
}
}
func AssertByteSlice(t *testing.T, actual, expected []byte) {
if !bytes.Equal(actual, expected) {
t.Errorf("Expected bytes to be equal")
}
}
func TestTableValue(t *testing.T) {
table, _ := NewTable(2, 2, 16)
table.SetRow(0, []byte{123, 81, 0, 0})
table.SetRow(1, []byte{1, 2, 3, 4})
AssertByteSlice(t, table.data, []byte{123, 81, 0, 0, 1, 2, 3, 4})
AssertByteSlice(t, table.GetValue(0, 0), []byte{123, 81})
AssertByteSlice(t, table.GetValue(0, 1), []byte{0, 0})
AssertByteSlice(t, table.GetRow(0), []byte{123, 81, 0, 0})
AssertByteSlice(t, table.GetValue(1, 0), []byte{1, 2})
AssertByteSlice(t, table.GetValue(1, 1), []byte{3, 4})
AssertByteSlice(t, table.GetRow(1), []byte{1, 2, 3, 4})
}
func TestGBC(t *testing.T) {
K := NewTable(3, 2, 128)
K, err := NewTable(3, 2, 16)
if err != nil {
t.Fatalf("Error: %s", err)
}
K.randomizeTable()
res := GBC(K, 0, 1, 2, ORGate, g)
print(res)
crctval := append(K.GetValue(2,1),make([]byte,16)...)
crctidx := make([]int,0)
for i := range res {
if bytes.Equal(res[i], crctval) {
crctidx = append(crctidx,i)
p := Protocol{testingG}
res := p.garbleGate(K, 0, 1, 2, ORGate)
count := 0
for i := 0; i < 4; i++ {
index := i * K.valueLen
if bytes.Equal(res[index:index+K.valueLen], K.GetValue(2, 1)) {
count++
}
}
if len(crctidx) != 3 {
if count != 3 {
t.Error("Function should return 3 true labels and 1 false label when using an OR Gate")
}
res = GBC(K, 0, 1, 2, ANDGate, g)
crctidx = make([]int,0)
for i := range res {
if bytes.Equal(res[i], crctval) {
crctidx = append(crctidx,i)
res = p.garbleGate(K, 0, 1, 2, ANDGate)
count = 0
for i := 0; i < 4; i++ {
index := i * K.valueLen
if bytes.Equal(res[index:index+K.valueLen], K.GetValue(2, 1)) {
count++
}
}
if len(crctidx) != 1 {
if count != 1 {
t.Error("Function should return 1 true labels and 3 false label when using an AND Gate")
}
}
func g(A, B []byte, i int) []byte {
func testingG(A, B []byte, i int) []byte {
return make([]byte, len(A)*2)
}
func TestDecode(t *testing.T) {
}
......@@ -2,8 +2,10 @@ package util
import (
"crypto/rand"
"encoding/binary"
"io"
"math/big"
mathRand "math/rand"
)
var one = big.NewInt(1)
......@@ -21,3 +23,21 @@ func RandInt(random io.Reader, n *big.Int) (r *big.Int, err error) {
return