Commit 2383257e authored by Anders Jensen Løvig's avatar Anders Jensen Løvig
Browse files

Refactor oblivious transfer

parent 1d8c8818
......@@ -4,27 +4,40 @@ import (
"crycomp/internal/blood"
"crycomp/internal/crypto/oblivious"
"fmt"
"strconv"
)
func runBob(y int, bob *oblivious.Party) error {
// Encode Bob's blood column as bytes.
data := make([][]byte, len(blood.Table))
for i := 0; i < len(blood.Table); i++ {
if blood.Table[i][y] {
data[i] = []byte{1}
} else {
data[i] = []byte{0}
}
}
return bob.Send(data)
}
// RunProtocol runs the protocol between receiving blood type x and donor blood
// type y.
func RunProtocol(x, y int) (bool, error) {
var ydata [][]byte
for i := 0; i < len(blood.Table); i++ {
ydata = append(ydata, strconv.AppendBool(make([]byte,0),blood.Table[i][y]))
}
p := oblivious.NewProtocol(x, ydata)
// We let Bob listen until Alice sends.
// This kinda simulates talking over a network.
p := oblivious.NewProtocol(len(blood.Table))
go func() {
p.Sender.Listen()
err := runBob(y, p.S)
if err != nil {
panic(err)
}
}()
data, err := p.Receiver.ObliviousTransfer(8)
result, _ := strconv.ParseBool(string(data))
return result, err
alice := p.R
data, err := alice.Receive(x)
if err != nil {
return false, err
}
return len(data) == 1 && data[0] == 1, nil
}
func main() {
......
......@@ -10,32 +10,86 @@ import (
)
////////////////////////////////////////////////////////////////
const NumInputs = 3 // How many inputs each player provide
const NumWires = 11
const K = 128
////////////////////////////////////////////////////////////////
type Protocol struct {
A, B *Party
oblivP *oblivious.Protocol
}
func NewProtocol() *Protocol {
connA2B := make(chan []byte)
connB2A := make(chan []byte)
obliv := oblivious.NewProtocol()
return &Protocol{
A: NewParty(connA2B, connB2A),
B: NewParty(connB2A, connA2B),
oblivP: New,
}
}
type Party struct {
x []bool
Circuit *garbled.GCircuit
InChan chan [][]byte
CircuitChan chan garbled.Table
EncInput [][]byte
// Channels for communication
send chan []byte
receive chan []byte
x []bool
Circuit *garbled.Circuit
Out chan []byte
TableChan chan garbled.Table
encodedX []byte
*oblivious.Party
result bool
outChan chan []byte
}
func NewParty(x int, inChan chan [][]byte, circuitChan chan garbled.Table, outChan chan []byte, party *oblivious.Party) (p *Party) {
func NewParty(in chan []byte, out chan []byte) (p *Party) {
return &Party{
x: util.Int2Bools(x, 3),
EncInput: make([][]byte, 6),
InChan: inChan,
outChan: outChan,
CircuitChan: circuitChan,
Party: party,
send: in,
receive: out,
}
}
func (alice *Party) RunAlice(x []bool) (err error) {
// 1. Garble
F, _ := garbled.NewTable(NumWires-NumInputs*2, 4, K*2)
F.SetData(<-alice.receive)
fmt.Println("Alice: Received F from Bob")
// 2. Encode Bob's input
fmt.Println("Alice: Received Y from Bob")
_ = <-alice.receive
return
}
func (bob *Party) RunBob(y []bool) (err error) {
// 1. Garble
bob.GenerateCircuit()
fmt.Println("Bob: Sending F to Alice")
bob.send <- bob.Circuit.F.GetData()
// 2. Encode Bob's input
Y := garbled.Encode(bob.Circuit.E[1], bob.x, len(bob.x))
fmt.Println("Bob: Sending Y to Alice")
bob.send <- Y
return
}
func (p *Party) GenerateCircuit() {
var err error
p.Circuit, err = garbled.NewGarbBloodCircuit()
p.Circuit, err = garbled.NewGarbBloodCircuit(NumInputs, NumWires, K)
if err != nil {
panic(err)
}
......@@ -43,7 +97,8 @@ func (p *Party) GenerateCircuit() {
func (p *Party) OTSend() {
for i := 0; i < 3; i++ {
p.Listen(p.Circuit.E[i+3])
// p.Listen(p.Circuit.E[i+3])
p.Listen(nil)
}
}
......@@ -73,74 +128,83 @@ func OTransfer(sender, receiver *Party) {
// RunProtocol runs the protocol between receiving blood type x and donor blood
// type y.
// outputMode: 0: Alice learns, 1: Bob learns else both learns
func RunProtocol(x, y int, outputMode int) (bool, error) {
inCh, circCh, outCh := make(chan [][]byte), make(chan garbled.Table), make(chan []byte)
p1, p2 := oblivious.NewPartyPair()
Alice, Bob := NewParty(x, inCh, circCh, outCh, p1), NewParty(y, inCh, circCh, outCh, p2)
fmt.Printf("Running protocol with [%d,%d]\n", x, y)
go func(Bob *Party) {
// 1. Garble
Bob.GenerateCircuit()
Bob.CircuitChan <- *Bob.Circuit.F
fmt.Printf("Bob has generated the circuit and sent F to alice\n")
// 2. Encode Bobs input
Bob.EncInput = garbled.Enc(Bob.Circuit.E[:3], Bob.x)
Bob.InChan <- Bob.EncInput
fmt.Printf("Bob has garbled his input and sent it to alice\n")
// 3. Encode Alices input
Bob.OTSend()
fmt.Printf("Bob has finished garbling alices input\n")
// 5. Output
if outputMode != 1 { // Alice learns
// Bob sends d
Bob.InChan <- Bob.Circuit.D
fmt.Printf("Bob sends D so Alice can learn\n")
func RunProtocol(x, y int, outputMode int) (z bool, err error) {
z = false
p := NewProtocol()
// Concurrently run Bob
go func() {
err := p.B.RunBob(util.Int2Bools(y, 3))
if err != nil {
panic(err) // TODO do not panic
}
if outputMode != 0 { // Bob learns
// Alice sends Z
Bob.result = garbled.Decode(Bob.Circuit.D, <-Bob.outChan)
fmt.Printf("Bob has received Z so he can learn\n")
}()
}
}(Bob)
err = p.A.RunAlice(util.Int2Bools(x, 3))
return
// go func(bob *Party) {
// 3. Encode Alice's input
// Bob.InChan <- Bob.EncInput
// fmt.Printf("Bob has garbled his input and sent it to alice\n")
// // 3. Encode Alices input
// Bob.OTSend()
// fmt.Printf("Bob has finished garbling alices input\n")
// // 5. Output
// if outputMode != 1 { // Alice learns
// // Bob sends d
// Bob.InChan <- Bob.Circuit.D
// fmt.Printf("Bob sends D so Alice can learn\n")
// }
// if outputMode != 0 { // Bob learns
// // Alice sends Z
// Bob.result = garbled.Decode(Bob.Circuit.D, <-Bob.outChan)
// fmt.Printf("Bob has received Z so he can learn\n")
// }
// }(bob)
//go func (Alice *Party) {
// 1. Garble
F := <-Alice.CircuitChan
Alice.Circuit = &garbled.GCircuit{F: &F}
fmt.Printf("Alice has received F\n")
// 2. Encode Bobs input
Alice.EncInput = <-Alice.InChan
// 3. Encode Alices input
X, _ := Alice.OTReceive()
Alice.EncInput = append(Alice.EncInput, X...)
fmt.Printf("Alice is finished garbling her input\n")
// 4. Evaluation
Z, eval_err := garbled.Eval(Alice.Circuit.F, Alice.EncInput)
if eval_err != nil {
return false, eval_err
}
fmt.Printf("Alice has evaluated the circuit\n")
// 5. Output
if outputMode != 1 { // Alice learns
// Bob sends d
Alice.result = garbled.Decode(<-Alice.InChan, Z)
fmt.Printf("Alice has learned the output\n")
}
if outputMode != 0 { // Bob learns
// Alice sends Z
Alice.outChan <- Z
fmt.Printf("Alice has sent Z to bob\n")
}
//}(Alice)
// F := <-alice.TableChan
if outputMode != 1 {
return Alice.result, nil
} else {
return Bob.result, nil
}
// alice.Circuit = &garbled.Circuit{F: &F}
// fmt.Printf("Alice has received F\n")
// // 2. Encode Bobs input
// alice.EncInput = <-alice.InChan
// // 3. Encode Alices input
// X, _ := alice.OTReceive()
// alice.EncInput = append(alice.EncInput, X...)
// fmt.Printf("Alice is finished garbling her input\n")
// // 4. Evaluation
// Z, eval_err := garbled.Eval(alice.Circuit.F, alice.EncInput)
// if eval_err != nil {
// return false, eval_err
// }
// fmt.Printf("Alice has evaluated the circuit\n")
// // 5. Output
// if outputMode != 1 { // Alice learns
// // Bob sends d
// alice.result = garbled.Decode(<-alice.InChan, Z)
// fmt.Printf("Alice has learned the output\n")
// }
// if outputMode != 0 { // Bob learns
// // Alice sends Z
// alice.outChan <- Z
// fmt.Printf("Alice has sent Z to bob\n")
// }
// //}(Alice)
// if outputMode != 1 {
// return alice.result, nil
// } else {
// return Bob.result, nil
// }
// return false, nil
}
func main() {
......
......@@ -44,23 +44,13 @@ func (t *Table) index(r, c int) int {
return r*t.cols*t.valueLen + c*t.valueLen
}
// // SetValue sets the byte string value at row i and column j.
// func (t *Table) SetValue(r, c int, data []byte) error {
// if len(data) != t.valueLen {
// return errors.New("data have invalid length")
// }
// index := t.index(r, c)
// copy(t.data[index:index+t.valueLen], data)
// return nil
// }
// GetValue returns a byte slice containing the value at row i and column j.
func (t *Table) GetValue(r, c int) []byte {
// getValue returns a byte slice containing the value at row i and column j.
func (t *Table) getValue(r, c int) []byte {
index := t.index(r, c)
return t.data[index : index+t.valueLen]
}
func (t *Table) SetRow(r int, data []byte) {
func (t *Table) setRow(r int, data []byte) {
if len(data) != t.cols*t.valueLen {
panic("data have invalid length")
}
......@@ -68,34 +58,26 @@ func (t *Table) SetRow(r int, data []byte) {
copy(t.data[index:index+t.cols*t.valueLen], data)
}
func (t *Table) GetRow(r int) []byte {
func (t *Table) getRow(r int) []byte {
index := t.index(r, 0)
return t.data[index : index+t.cols*t.valueLen]
}
// func (t *Table) SetValueOld(data []byte, i, s int) {
// valStart := i*t.valueLen + s*t.valueLen
// copy(t.data[valStart:valStart+t.valueLen], data)
// }
// func (t *Table) GetValueOld(i, s int) []byte {
// valStart := i*t.valueLen + s*t.valueLen
// return t.data[valStart : valStart+t.valueLen]
// }
func (t *Table) getRows(r, count int) []byte {
index := t.index(r, 0)
return t.data[index : index+count*t.cols*t.valueLen]
}
// func (t *Table) SetRowOld(data [][]byte, i int) {
// for j, val := range data {
// t.SetValueOld(val, i, j)
// }
// }
func (t *Table) SetData(data []byte) {
if len(data) != len(t.data) {
panic("data have invalid length")
}
t.data = data
}
// func (t *Table) GetRowOld(i int) [][]byte {
// values := make([][]byte, 0)
// for j := 0; j < t.cols; j++ {
// values = append(values, t.GetValueOld(i, j))
// }
// return values
// }
func (t *Table) GetData() []byte {
return t.data
}
func (t *Table) getKFirstValues(numInputWire int) [][][]byte {
values := make([][][]byte, 0)
......@@ -105,101 +87,124 @@ func (t *Table) getKFirstValues(numInputWire int) [][][]byte {
return values
}
func (t *Table) randomizeTable() {
rand.Read(t.data)
func (t *Table) randomizeTable() error {
_, err := rand.Read(t.data)
return err
}
type GCircuit struct {
type Circuit struct {
F *Table
E [][][]byte
D [][]byte
E []*Table
D *Table
}
type Protocol struct {
type Params struct {
G func(a, b []byte, i int) []byte
}
func NewGarbBloodCircuit() (c *GCircuit, err error) {
p := Protocol{G}
numInput := 6
numWires := 11 // Includes numInput
k_bits := 128
func NewGarbBloodCircuit(numInputs, numWires, k int) (c *Circuit, err error) {
p := Params{G}
// 1. Create two random strings for each wire.
K, err := NewTable(numWires, 2, k_bits)
kTable, err := NewTable(numWires, 2, k)
if err != nil {
return
}
err = kTable.randomizeTable()
if err != nil {
return
}
K.randomizeTable()
// 2. Create a garbled table for all C values (4 for each gate)
C, err := NewTable(numWires-numInput, 4, k_bits*2)
// 2. Create a garbled table for all fTable values (4 for each gate)
fTable, err := NewTable(numWires-numInputs*2, 4, k*2)
if err != nil {
return
}
// // 1st layer are OR-gates with b input being NOT
C.SetRow(0, p.garbleGate(K, 0, 3, 6, ORGateWithNot))
C.SetRow(1, p.garbleGate(K, 1, 4, 7, ORGateWithNot))
C.SetRow(2, p.garbleGate(K, 2, 5, 8, ORGateWithNot))
// Create F
fTable.setRow(0, p.garbleGate(kTable, 0, 3, 6, ORGateWithNot))
fTable.setRow(1, p.garbleGate(kTable, 1, 4, 7, ORGateWithNot))
fTable.setRow(2, p.garbleGate(kTable, 2, 5, 8, ORGateWithNot))
fTable.setRow(3, p.garbleGate(kTable, 6, 7, 9, ANDGate))
fTable.setRow(4, p.garbleGate(kTable, 8, 9, 10, ANDGate))
C.SetRow(3, p.garbleGate(K, 6, 7, 9, ANDGate))
C.SetRow(4, p.garbleGate(K, 8, 9, 10, ANDGate))
// Create e
e := make([]*Table, 2)
e[0], _ = NewTable(numInputs, 2, k)
e[0].SetData(kTable.getRows(0, numInputs))
c = &GCircuit{
F: C,
E: K.getKFirstValues(numInput),
// D: K.GetRowOld(numWires),
e[1], _ = NewTable(numInputs, 2, k)
e[1].SetData(kTable.getRows(numInputs, numInputs))
// Create d
d, _ := NewTable(1, 2, k)
d.SetData(kTable.getRow(numWires - 1))
c = &Circuit{
F: fTable,
E: e,
D: d,
}
return
}
// Encode x using the e table. If ey == true, we use the second
// half of e, otherwise first half
func Encode(e *Table, x []bool, rowOffset int) (X []byte) {
X = make([]byte, 0)
for i, b := range x {
if b {
X = append(X, e.getValue(i+rowOffset, 1)...)
} else {
X = append(X, e.getValue(i+rowOffset, 0)...)
}
}
if len(x)*e.valueLen != len(X) {
panic("Wrong length of X")
}
// X = make([][]byte, 0)
// for i, val := range x {
// if val {
// X = append(X, e[i][1])
// } else {
// X = append(X, e[i][0])
// }
// }
return
}
//Creates a row (C_0^i, C_1^i, C_2^i ,C_3^i) for the garbled table, where
func (p *Protocol) garbleGate(K *Table, Li, Ri, out int, c func(a, b int) int) (cRow []byte) {
cRow = make([]byte, K.valueLen*4)
// Random permutation
perm := cryptoUtil.Perm(4)
func (p *Params) garbleGate(K *Table, Li, Ri, out int, c func(a, b int) int) []byte {
cRow := make([]byte, K.valueLen*8) // because 4 values of double length
perm := cryptoUtil.Perm(4) // Random permutation
// for (a,b) in {0,1} x {0,1}
for a := 0; a <= 1; a++ {
for b := 0; b <= 1; b++ {
left := p.G(K.GetValue(Li, a), K.GetValue(Ri, b), out)
left := p.G(K.getValue(Li, a), K.getValue(Ri, b), out)
right := make([]byte, K.valueLen*2)
copy(right[:K.valueLen], K.GetValue(out, c(a, b)))
// right = append(right, make([]byte, 16)...)
copy(right[:K.valueLen], K.getValue(out, c(a, b)))
// The index to write this value to. This depends in the row permutation.
rowI := perm[a*2+b] * K.valueLen
util.XOR(left, left, right) // XOR with left as destination.
copy(cRow[rowI:rowI+K.valueLen], left)
// cRow[rowI : rowI+K.valueLen] =
// cVals[a*2+b] = util.XOR(left, right)
}
}
return
return cRow
}
func G(A, B []byte, i int) []byte {
hash := sha256.New()
hash.Write(A)
hash.Write(append(B, byte(i)))
hash.Write(B)
hash.Write([]byte{byte(i)})
return hash.Sum(nil)
}
func Enc(e [][][]byte, x []bool) (X [][]byte) {
X = make([][]byte, 0)
for i, val := range x {
if val {
X = append(X, e[i][1])
} else {
X = append(X, e[i][0])
}
}
return
}
func Eval(C *Table, X [][]byte) ([]byte, error) {
var K = X
var res []byte
......@@ -230,45 +235,45 @@ func Decode(d [][]byte, Z []byte) bool {
// Circuit
////////////////////////////////////////////////////////////////
type Circuit struct {
wires []int
gates []Gate
outputWire int
numInput int
}
// type Circuit struct {
// wires []int
// gates []Gate
// outputWire int
// numInput int
// }
func NewCircuit(numInput int) *Circuit {
return &Circuit{
wires: make([]int, numInput),
gates: make([]Gate, 0),
numInput: numInput,
}
}
// func NewCircuit(numInput int) *Circuit {
// return &Circuit{
// wires: make([]int, numInput),
// gates: make([]Gate, 0),
// numInput: numInput,
// }
// }
func (c *Circuit) getInputWires(i int) (Li, Ri int) {
return i, c.numInput + i
}
// func (c *Circuit) getInputWires(i int) (Li, Ri int) {
// return i, c.numInput + i
// }
func (c *Circuit) AddGate(Li, Ri int, f func(a, b int) int) *Gate {
c.wires = append(c.wires, 0)
g := newGate(Li, Ri, len(c.wires)-1, f)
c.gates = append(c.gates, *g)
return g
}
// func (c *Circuit) AddGate(Li, Ri int, f func(a, b int) int) *Gate {
// c.wires = append(c.wires, 0)
// g := newGate(Li, Ri, len(c.wires)-1, f)
// c.gates = append(c.gates, *g)
// return g