Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions benches/bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,13 +160,10 @@ fn _bench_client_query(
mod bench_utils {
use rand_core::{OsRng, RngCore};
pub fn generate_db_eles(num_eles: usize, ele_byte_len: usize) -> Vec<String> {
let mut eles = Vec::with_capacity(num_eles);
for _ in 0..num_eles {
(0..num_eles).map(|_| {
let mut ele = vec![0u8; ele_byte_len];
OsRng.fill_bytes(&mut ele);
let ele_str = base64::encode(ele);
eles.push(ele_str);
}
eles
base64::encode(ele)
}).collect()
}
}
38 changes: 18 additions & 20 deletions src/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,7 @@ impl Shard {
.map(|i| self.db.vec_mult(q, i))
.collect(),
);
let ser = bincode::serialize(&resp);

Ok(ser?)
Ok(bincode::serialize(&resp)?)
}

/// Returns the database
Expand All @@ -86,11 +84,10 @@ impl Shard {
&self.base_params
}

pub fn into_row_iter(&self) -> std::vec::IntoIter<std::string::String> {
(0..self.get_db().get_matrix_height())
.map(|i| self.get_db().get_db_entry(i))
.collect::<Vec<String>>()
.into_iter()
pub fn into_row_iter(&self) -> impl Iterator<Item = String> {
let db = self.get_db();
(0..db.get_matrix_height())
.map(|i| db.get_db_entry(i))
}
}

Expand Down Expand Up @@ -127,8 +124,7 @@ impl QueryParams {
}
self.used = true;
let query_indicator = get_rounding_factor(self.plaintext_bits);
let mut lhs = Vec::new();
lhs.clone_from(&self.lhs.clone());
let mut lhs = self.lhs.clone();
let (result, check) = lhs[row_index].overflowing_add(query_indicator);
if !check {
lhs[row_index] = result;
Expand Down Expand Up @@ -166,9 +162,14 @@ impl Response {
let plaintext_size = get_plaintext_size(qp.plaintext_bits);

// perform division and rounding
(0..Database::get_matrix_width(qp.elem_size, qp.plaintext_bits))
.map(|i| {
let unscaled_res = self.0[i].wrapping_sub(qp.rhs[i]);
let rg = 0..Database::get_matrix_width(qp.ele_size, qp.plaintext_bits);
self.0[rg.clone()]
.iter()
.copied()
.zip(qp.rhs[rg].iter().copied())
.into_iter()
.map(|(x, y)| x.wrapping_sub(y))
.map(|unscaled_res| {
let scaled_res = unscaled_res / rounding_factor;
let scaled_rem = unscaled_res % rounding_factor;
let mut rounded_res = scaled_res;
Expand Down Expand Up @@ -264,13 +265,10 @@ mod tests {

// This will generate random elements for test databases
fn generate_db_elems(num_elems: usize, elem_byte_len: usize) -> Vec<String> {
let mut elems = Vec::with_capacity(num_elems);
for _ in 0..num_elems {
let mut elem = vec![0u8; elem_byte_len];
(0..num_eles).map(|_| {
let mut elem = vec![0u8; ele_byte_len];
OsRng.fill_bytes(&mut elem);
let elem_str = base64::encode(elem);
elems.push(elem_str);
}
elems
base64::encode(elem)
}).collect()
}
}
51 changes: 19 additions & 32 deletions src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ impl Database {

/// Returns the width of the DB matrix
pub fn get_matrix_width_self(&self) -> usize {
Database::get_matrix_width(self.get_elem_size(), self.get_plaintext_bits())
Database::get_matrix_width(self.elem_size, self.plaintext_bits)
}

/// Get the matrix size
Expand Down Expand Up @@ -149,13 +149,7 @@ impl BaseParams {
let lhs =
swap_matrix_fmt(&generate_lwe_matrix_from_seed(public_seed, dim, m));
(0..db.get_matrix_width_self())
.map(|i| {
let mut col = Vec::with_capacity(m);
for r in &lhs {
col.push(db.vec_mult(r, i));
}
col
})
.map(|i| lhs.iter().map(|r| db.vec_mult(r, i)).collect())
.collect()
}

Expand All @@ -170,9 +164,9 @@ impl BaseParams {

/// Computes c = s*(A*DB) using the RHS of the public parameters
pub fn mult_right(&self, s: &[u32]) -> ResultBoxedError<Vec<u32>> {
let cols = &self.rhs;
(0..cols.len())
.map(|i| vec_mult_u32_u32(s, &cols[i]))
self.rhs
.iter()
.map(|i| vec_mult_u32_u32(s, i))
.collect()
}

Expand Down Expand Up @@ -206,10 +200,10 @@ impl CommonParams {
/// Computes b = s*A + e using the seed used to generate the matrix of
/// the public parameters
pub fn mult_left(&self, s: &[u32]) -> ResultBoxedError<Vec<u32>> {
let cols = self.as_matrix();
(0..cols.len())
self.0
.iter()
.map(|i| {
let s_a = vec_mult_u32_u32(s, &cols[i])?;
let s_a = vec_mult_u32_u32(s, i)?;
let e = random_ternary();
Ok(s_a.wrapping_add(e))
})
Expand All @@ -233,24 +227,17 @@ fn construct_rows(
plaintext_bits: usize,
) -> ResultBoxedError<Vec<Vec<u32>>> {
let row_width = Database::get_matrix_width(elem_size, plaintext_bits);

let result = (0..m).map(|i| -> ResultBoxedError<Vec<u32>> {
let mut row = Vec::with_capacity(row_width);
let data = &elements[i];
let bytes = base64::decode(data)?;
let bits = bytes_to_bits_le(&bytes);
for i in 0..row_width {
let end_bound = (i + 1) * plaintext_bits;
if end_bound < bits.len() {
row.push(bits_to_u32_le(&bits[i * plaintext_bits..end_bound])?);
} else {
row.push(bits_to_u32_le(&bits[i * plaintext_bits..])?);
}
}
Ok(row)
});

result.collect()
elements.iter()
.take(m)
.map(|data| {
let bytes = base64::decode(&data)?;
let bits = bytes_to_bits_le(&bytes);
bits.chunks(plaintext_bits)
.take(row_width)
.map(|bitspl| bits_to_u32_le(bitspl).map_err(Into::into))
.collect::<ResultBoxedError<Vec<u32>>>()
})
.collect()
}

fn generate_seed() -> [u8; 32] {
Expand Down
137 changes: 64 additions & 73 deletions src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,16 +57,11 @@ pub mod matrices {
lwe_dim: usize,
width: usize,
) -> Vec<Vec<u32>> {
let mut a = Vec::with_capacity(width);
use core::iter::repeat_with;
let mut rng = get_seeded_rng(seed);
for _ in 0..width {
let mut v = Vec::with_capacity(lwe_dim);
for _ in 0..lwe_dim {
v.push(rng.next_u32());
}
a.push(v);
}
a
repeat_with(||
repeat_with(|| rng.next_u32()).take(lwe_dim).collect()
).take(width).collect()
}

/// Multiplies a u32 vector with a u32 column vector
Expand All @@ -80,11 +75,10 @@ pub mod matrices {
col.len(),
))));
}
let mut acc = 0u32;
for i in 0..row.len() {
acc = acc.wrapping_add(row[i].wrapping_mul(col[i]));
}
Ok(acc)
Ok(row.iter()
.zip(col.iter())
.map(|(&x, &y)| x.wrapping_mul(y))
.fold(0u32, |acc, i| acc.wrapping_add(i)))
}

/// Returns a seeded RNG for sampling values
Expand Down Expand Up @@ -115,23 +109,19 @@ pub mod matrices {
}
// Now we return {0,1,-1} depending on whether the sampled value
// sits in the first, second or third sampling interval
let mut tern = 0;
if val > TERNARY_INTERVAL_SIZE && val <= TERNARY_INTERVAL_SIZE * 2 {
tern = 1;
1
} else if val > TERNARY_INTERVAL_SIZE * 2 {
tern = u32::MAX;
u32::MAX
} else {
0
}
tern
}

/// Simulates a ternary error vector of width size by sampling randomly,
/// using rejection sampling, from {0,1,u32::MAX}
pub fn random_ternary_vector(width: usize) -> Vec<u32> {
let mut row = Vec::new();
for _ in 0..width {
row.push(random_ternary());
}
row
(0..width).map(|_| random_ternary()).collect()
}
}

Expand All @@ -140,52 +130,60 @@ pub mod format {
use crate::errors::ErrorUnexpectedInputSize;
use std::convert::TryInto;

fn u8_to_bits_le(byte: u8) -> Vec<bool> {
let mut ret = Vec::new();
for i in 0..8 {
ret.push(2u8.pow(i as u32) & byte > 0);
}
ret
fn u8_to_bits_le(byte: u8) -> [bool; 8] {
[
2u8.pow(0) & byte != 0,
2u8.pow(1) & byte != 0,
2u8.pow(2) & byte != 0,
2u8.pow(3) & byte != 0,

2u8.pow(4) & byte != 0,
2u8.pow(5) & byte != 0,
2u8.pow(6) & byte != 0,
2u8.pow(7) & byte != 0,
]
}

fn bits_to_u8_le(xs: &[bool]) -> u8 {
assert!(xs.len() <= 8);
xs.iter()
.enumerate()
.filter(|(_, &bit)| bit)
.map(|(i, _)| 2u8.pow(i as u32))
.sum()
}

pub fn u32_to_bits_le(x: u32, bit_len: usize) -> Vec<bool> {
let bytes = x.to_le_bytes();
let mut bits = Vec::with_capacity(bytes.len());
for byte in bytes {
bits.extend(u8_to_bits_le(byte));
}
bits[..bit_len].to_vec()
x.to_le_bytes()
.into_iter()
.flat_map(u8_to_bits_le)
.take(bit_len)
.collect()
}

pub fn bits_to_bytes_le(bits: &[bool]) -> Vec<u8> {
let mut bytes = vec![0u8; (bits.len() + 7) / 8];
for (i, &bit) in bits.iter().enumerate() {
if bit {
let idx = ((i as f64) / 8f64).floor() as usize;
let exp = (i % 8) as u32;
bytes[idx] += 2u8.pow(exp);
}
}
bytes
bits.chunks(8)
.map(|bits8| bits_to_u8_le(bits8))
.collect()
}

pub fn bytes_to_bits_le(bytes: &[u8]) -> Vec<bool> {
bytes
.iter()
.map(|b| u8_to_bits_le(*b))
.collect::<Vec<Vec<bool>>>()
.iter()
.fold(Vec::new(), |mut acc, next| {
acc.extend(next);
acc
})
bytes.iter()
.copied()
.flat_map(u8_to_bits_le)
.collect()
}

pub fn bits_to_u32_le(
bits: &[bool],
) -> Result<u32, ErrorUnexpectedInputSize> {
let mut bytes = bits_to_bytes_le(bits);
let u32_len = std::mem::size_of::<u32>();
bytes_to_u32_le(bits_to_bytes_le(bits))
}

pub fn bytes_to_u32_le(
mut bytes: Vec<u8>,
) -> Result<u32, ErrorUnexpectedInputSize> {
let u32_len = core::mem::size_of::<u32>();
let byte_len = bytes.len();
if byte_len > u32_len {
return Err(ErrorUnexpectedInputSize::new(format!(
Expand All @@ -202,17 +200,11 @@ pub mod format {
pub fn u32_sized_bytes_from_vec(
bytes: Vec<u8>,
) -> Result<[u8; 4], ErrorUnexpectedInputSize> {
let sized_vec: [u8; 4] = match bytes.try_into() {
Ok(b) => b,
Err(e) => {
return Err(ErrorUnexpectedInputSize::new(format!(
"Unexpected vector size: {:?}",
e,
)))
}
};

Ok(sized_vec)
bytes.try_into()
.map_err(|e| ErrorUnexpectedInputSize::new(format!(
"Unexpected vector size: {:?}",
e,
)))
}

pub fn bytes_from_u32_slice(
Expand All @@ -221,16 +213,15 @@ pub mod format {
total_bit_len: usize,
) -> Vec<u8> {
let remainder = total_bit_len % entry_bit_len;
let mut bits = Vec::with_capacity(entry_bit_len * v.len());
for i in 0..v.len() {
let bits: Vec<_> = v.iter().enumerate().flat_map(|(i, &vi)| {
// We extract either the full amount of bits, or the remainder from
// the last index
if i != v.len() - 1 {
bits.extend(u32_to_bits_le(v[i], entry_bit_len));
u32_to_bits_le(vi, if i != v.len() - 1 {
entry_bit_len
} else {
bits.extend(u32_to_bits_le(v[i], remainder));
}
}
remainder
})
}).collect();
bits_to_bytes_le(&bits)
}

Expand Down