Source code

Revision control

Copy as Markdown

Other Tools

/* This Source Code Form is subject to the terms of the Mozilla Public
* License, v. 2.0. If a copy of the MPL was not distributed with this
* file, You can obtain one at https://mozilla.org/MPL/2.0/. */
use std::error::Error;
use std::io::Cursor;
use log::{debug, warn};
use prio::field::Field128;
use prio::vdaf::prio3::{Prio3Histogram, Prio3Sum, Prio3SumVec};
use prio::vdaf::prio3::{Prio3InputShare, Prio3PublicShare};
use thin_vec::ThinVec;
pub mod types;
use types::HpkeConfig;
use types::PlaintextInputShare;
use types::Report;
use types::ReportID;
use types::ReportMetadata;
use types::Time;
use prio::codec::Encode;
use prio::codec::{decode_u16_items, encode_u32_items};
use prio::vdaf::Client;
use crate::types::HpkeCiphertext;
extern "C" {
    /// One-shot HPKE encryption, declared with C linkage; its definition is
    /// not in this file.
    ///
    /// Arguments as passed by `hpke_encrypt_wrapper` in this file:
    /// * `aKey`/`aKeyLength` - the recipient's HPKE public key.
    /// * `aInfo`/`aInfoLength` - the HPKE `info` string.
    /// * `aAad`/`aAadLength` - the associated data authenticated with the ciphertext.
    /// * `aPlaintext`/`aPlaintextLength` - the bytes to encrypt.
    /// * `aOutputEncapsulatedKey` / `aOutputShare` - empty buffers that are
    ///   expected to be filled with the encapsulated key and the ciphertext
    ///   respectively (the wrapper uses them as `enc` and `payload`).
    ///
    /// Returns `false` on failure; the wrapper maps that to an error.
    ///
    /// # Safety
    ///
    /// NOTE(review): assumed contract of the C implementation - each input
    /// pointer must be valid for reads of its paired length for the duration
    /// of the call. Confirm against the definition.
    pub fn dapHpkeEncryptOneshot(
        aKey: *const u8,
        aKeyLength: u32,
        aInfo: *const u8,
        aInfoLength: u32,
        aAad: *const u8,
        aAadLength: u32,
        aPlaintext: *const u8,
        aPlaintextLength: u32,
        aOutputEncapsulatedKey: &mut ThinVec<u8>,
        aOutputShare: &mut ThinVec<u8>,
    ) -> bool;
}
/// A single integer measurement to be aggregated with `Prio3Sum`.
struct SumMeasurement {
    /// The value being reported.
    value: u32,
    /// Bit width handed to `Prio3Sum::new_sum` when sharding `value`.
    bits: usize,
}
/// A vector of integer measurements to be aggregated element-wise with
/// `Prio3SumVec`.
struct SumVecMeasurement<'m> {
    /// Borrowed measurement vector; its length becomes the VDAF vector length.
    value: &'m ThinVec<u32>,
    /// Per-element bit width handed to `Prio3SumVec::new_sum_vec`.
    bits: usize,
}
/// A histogram measurement to be aggregated with `Prio3Histogram`.
struct HistogramMeasurement {
    /// Index of the bucket this report contributes to.
    index: u32,
    /// Histogram length handed to `Prio3Histogram::new_histogram`.
    length: usize,
}
/// The DAP server a given input share is encrypted for.
///
/// Each discriminant is the byte appended to the HPKE `info` string when
/// encrypting that server's input share.
enum Role {
    /// The Leader aggregator.
    Leader = 2,
    /// The Helper aggregator.
    Helper = 3,
}
/// Number of aggregators every report is split between.
///
/// Prio itself supports more than two aggregators, but DAP as used here only
/// ever involves a Leader and a Helper.
const NUM_AGGREGATORS: u8 = 2;
/// A minimal wrapper around the FFI function `dapHpkeEncryptOneshot` which
/// mostly just converts datatypes.
///
/// Encrypts `plain_share` to the public key in `hpke_config`, using `info` as
/// the HPKE info string and `aad` as associated data.
///
/// # Errors
///
/// Returns an error if any input buffer is longer than `u32::MAX` bytes (the
/// FFI length type), or if the FFI call reports failure.
fn hpke_encrypt_wrapper(
    plain_share: &[u8],
    aad: &[u8],
    info: &[u8],
    hpke_config: &HpkeConfig,
) -> Result<HpkeCiphertext, Box<dyn std::error::Error>> {
    use std::convert::TryFrom;

    // The FFI takes u32 lengths. A plain `as` cast would silently truncate an
    // oversized buffer and make the C side process only a prefix of it.
    let key_len = u32::try_from(hpke_config.public_key.len())?;
    let info_len = u32::try_from(info.len())?;
    let aad_len = u32::try_from(aad.len())?;
    let plain_len = u32::try_from(plain_share.len())?;

    let mut encrypted_share = ThinVec::<u8>::new();
    let mut encapsulated_key = ThinVec::<u8>::new();
    // SAFETY: every pointer/length pair comes from a live slice that outlives
    // this call, and each length was checked above to fit in u32 without
    // truncation. The output buffers are exclusively borrowed for the call.
    let succeeded = unsafe {
        dapHpkeEncryptOneshot(
            hpke_config.public_key.as_ptr(),
            key_len,
            info.as_ptr(),
            info_len,
            aad.as_ptr(),
            aad_len,
            plain_share.as_ptr(),
            plain_len,
            &mut encapsulated_key,
            &mut encrypted_share,
        )
    };
    if !succeeded {
        return Err(Box::from("Encryption failed."));
    }
    Ok(HpkeCiphertext {
        config_id: hpke_config.id,
        enc: encapsulated_key.to_vec(),
        payload: encrypted_share.to_vec(),
    })
}
/// Seed length, in bytes, of the Prio3 public/input share types encoded here.
const SEED_SIZE: usize = 16;
/// Serializes the output of a Prio3 `shard` call.
///
/// Returns `(encoded_public_share, encoded_input_shares)`. The input shares
/// are encoded first, in order, and encoding stops at the first failure.
fn encode_prio3_shares(
    public_share: Prio3PublicShare<SEED_SIZE>,
    input_shares: Vec<Prio3InputShare<Field128, SEED_SIZE>>,
) -> Result<(Vec<u8>, Vec<Vec<u8>>), Box<dyn std::error::Error>> {
    debug_assert_eq!(input_shares.len(), NUM_AGGREGATORS as usize);

    let mut encoded_inputs = Vec::with_capacity(input_shares.len());
    for share in &input_shares {
        encoded_inputs.push(share.get_encoded()?);
    }

    let encoded_public = public_share.get_encoded()?;
    Ok((encoded_public, encoded_inputs))
}
/// A measurement that can be split into encoded Prio3 shares.
trait Shardable {
    /// Shards `self` using `nonce` and encodes the result.
    ///
    /// Returns `(encoded_public_share, encoded_input_shares)`. The implementations
    /// in this file produce one input share per aggregator.
    fn shard(
        &self,
        nonce: &[u8; 16],
    ) -> Result<(Vec<u8>, Vec<Vec<u8>>), Box<dyn std::error::Error>>;
}
/// Shards a single integer with `Prio3Sum`.
impl Shardable for SumMeasurement {
    fn shard(
        &self,
        nonce: &[u8; 16],
    ) -> Result<(Vec<u8>, Vec<Vec<u8>>), Box<dyn std::error::Error>> {
        let vdaf = Prio3Sum::new_sum(NUM_AGGREGATORS, self.bits)?;
        let widened = u128::from(self.value);
        let (public_share, input_shares) = vdaf.shard(&widened, nonce)?;
        encode_prio3_shares(public_share, input_shares)
    }
}
/// Shards a vector of integers with `Prio3SumVec`.
impl Shardable for SumVecMeasurement<'_> {
    fn shard(
        &self,
        nonce: &[u8; 16],
    ) -> Result<(Vec<u8>, Vec<Vec<u8>>), Box<dyn std::error::Error>> {
        let vector_len = self.value.len();
        let chunk_length = prio::vdaf::prio3::optimal_chunk_length(self.bits * vector_len);
        let vdaf = Prio3SumVec::new_sum_vec(NUM_AGGREGATORS, self.bits, vector_len, chunk_length)?;
        let widened: Vec<u128> = self.value.iter().copied().map(u128::from).collect();
        let (public_share, input_shares) = vdaf.shard(&widened, nonce)?;
        encode_prio3_shares(public_share, input_shares)
    }
}
/// Shards a single bucket index with `Prio3Histogram`.
impl Shardable for HistogramMeasurement {
    fn shard(
        &self,
        nonce: &[u8; 16],
    ) -> Result<(Vec<u8>, Vec<Vec<u8>>), Box<dyn std::error::Error>> {
        let vdaf = Prio3Histogram::new_histogram(
            NUM_AGGREGATORS,
            self.length,
            prio::vdaf::prio3::optimal_chunk_length(self.length),
        )?;
        let bucket = self.index as usize;
        let (public_share, input_shares) = vdaf.shard(&bucket, nonce)?;
        encode_prio3_shares(public_share, input_shares)
    }
}
/// Decodes an aggregator's advertised HPKE configurations and returns the
/// first one whose KEM, KDF and AEAD are all supported here.
///
/// # Errors
///
/// Fails if the list cannot be decoded, or if no advertised config matches.
fn select_hpke_config(encoded: &ThinVec<u8>) -> Result<HpkeConfig, Box<dyn Error>> {
    // Supported algorithm identifiers, as assigned in RFC 9180.
    const SUPPORTED_KEM: u16 = 0x20; // DHKEM(X25519, HKDF-SHA256)
    const SUPPORTED_KDF: u16 = 0x01; // HKDF-SHA256
    const SUPPORTED_AEAD: u16 = 0x01; // AES-128-GCM

    let advertised: Vec<HpkeConfig> = decode_u16_items(&(), &mut Cursor::new(encoded))?;
    advertised
        .into_iter()
        .find(|candidate| {
            candidate.kem_id == SUPPORTED_KEM
                && candidate.kdf_id == SUPPORTED_KDF
                && candidate.aead_id == SUPPORTED_AEAD
        })
        .ok_or_else(|| "No suitable HPKE config found.".into())
}
/// This function creates a full report - ready to send - for a measurement.
///
/// To do that it also needs the HPKE configurations for the endpoints and some
/// additional data which is part of the authentication.
///
/// # Arguments
///
/// * `leader_hpke_config_encoded` / `helper_hpke_config_encoded` - each
///   aggregator's encoded HPKE config list, decoded by `select_hpke_config`.
/// * `measurement` - the value to shard and encrypt.
/// * `task_id` - the DAP task ID; must be exactly 32 bytes.
/// * `time_precision` - passed to `Time::generate` for the report timestamp.
///
/// # Errors
///
/// Returns an error in any of these cases:
/// * `task_id` is not 32 bytes long.
/// * Either aggregator advertises no supported HPKE config.
/// * Sharding or encoding fails.
/// * Sharding does not yield exactly two input shares.
/// * HPKE encryption fails.
fn get_dap_report_internal<T: Shardable>(
    leader_hpke_config_encoded: &ThinVec<u8>,
    helper_hpke_config_encoded: &ThinVec<u8>,
    measurement: &T,
    task_id: &ThinVec<u8>,
    time_precision: u64,
) -> Result<Report, Box<dyn std::error::Error>> {
    // This runs underneath `extern "C"` entry points that report failure as
    // `false`. A panic (formerly `assert_eq!`) cannot be turned into that
    // return value, so a malformed caller-supplied task ID is an error instead.
    if task_id.len() != 32 {
        return Err(format!("Task ID must be 32 bytes, got {}.", task_id.len()).into());
    }

    let leader_hpke_config = select_hpke_config(leader_hpke_config_encoded)?;
    let helper_hpke_config = select_hpke_config(helper_hpke_config_encoded)?;
    let report_id = ReportID::generate();
    // The report ID doubles as the VDAF nonce.
    let (encoded_public_share, encoded_input_shares) = measurement.shard(report_id.as_ref())?;
    let plaintext_input_shares: Vec<Vec<u8>> = encoded_input_shares
        .into_iter()
        .map(|encoded_input_share| {
            PlaintextInputShare {
                extensions: Vec::new(),
                payload: encoded_input_share,
            }
            .get_encoded()
        })
        .collect::<Result<Vec<_>, _>>()?;
    debug!("Plaintext input shares computed.");

    // Checked destructuring instead of `[0]`/`[1]` indexing, which would panic
    // if sharding ever produced a different number of shares.
    let [leader_share, helper_share] = plaintext_input_shares.as_slice() else {
        return Err("Expected exactly one input share per aggregator.".into());
    };

    let time = Time::generate(time_precision);
    let metadata = ReportMetadata { report_id, time };
    // This quote from the standard describes which info and aad to use for the encryption:
    // enc, payload = SealBase(pk,
    //   "dap-09 input share" || 0x01 || server_role,
    //   input_share_aad, plaintext_input_share)
    let mut info = b"dap-09 input share\x01".to_vec();
    let mut aad = Vec::from(task_id.as_ref());
    metadata.encode(&mut aad)?;
    encode_u32_items(&mut aad, &(), &encoded_public_share)?;

    info.push(Role::Leader as u8);
    let leader_payload = hpke_encrypt_wrapper(leader_share, &aad, &info, &leader_hpke_config)?;
    debug!("Leader payload encrypted.");

    // Replace the trailing server-role byte with the Helper's.
    info.pop();
    info.push(Role::Helper as u8);
    let helper_payload = hpke_encrypt_wrapper(helper_share, &aad, &info, &helper_hpke_config)?;
    debug!("Helper payload encrypted.");

    Ok(Report {
        metadata,
        public_share: encoded_public_share,
        leader_encrypted_input_share: leader_payload,
        helper_encrypted_input_share: helper_payload,
    })
}
/// C entry point: builds and encodes a DAP report for a `Prio3Sum` measurement.
///
/// A thin C-interop wrapper around `get_dap_report_internal`. Any failure,
/// whether while building the report or while encoding it, is logged and
/// reported as `false`. On success the encoded report is appended to
/// `out_report` and `true` is returned.
#[no_mangle]
pub extern "C" fn dapGetReportPrioSum(
    leader_hpke_config_encoded: &ThinVec<u8>,
    helper_hpke_config_encoded: &ThinVec<u8>,
    measurement: u32,
    task_id: &ThinVec<u8>,
    bits: u32,
    time_precision: u64,
    out_report: &mut ThinVec<u8>,
) -> bool {
    let sum = SumMeasurement {
        value: measurement,
        bits: bits as usize,
    };
    let report = match get_dap_report_internal(
        leader_hpke_config_encoded,
        helper_hpke_config_encoded,
        &sum,
        task_id,
        time_precision,
    ) {
        Ok(report) => report,
        Err(_) => {
            warn!("Creating report failed!");
            return false;
        }
    };
    match report.get_encoded() {
        Ok(encoded) => {
            out_report.extend(encoded);
            true
        }
        Err(_) => {
            warn!("Encoding report failed!");
            false
        }
    }
}
/// C entry point: builds and encodes a DAP report for a `Prio3SumVec`
/// measurement.
///
/// Any failure is logged and reported as `false`. On success the encoded
/// report is appended to `out_report` and `true` is returned.
#[no_mangle]
pub extern "C" fn dapGetReportPrioSumVec(
    leader_hpke_config_encoded: &ThinVec<u8>,
    helper_hpke_config_encoded: &ThinVec<u8>,
    measurement: &ThinVec<u32>,
    task_id: &ThinVec<u8>,
    bits: u32,
    time_precision: u64,
    out_report: &mut ThinVec<u8>,
) -> bool {
    let sum_vec = SumVecMeasurement {
        value: measurement,
        bits: bits as usize,
    };
    let report = match get_dap_report_internal(
        leader_hpke_config_encoded,
        helper_hpke_config_encoded,
        &sum_vec,
        task_id,
        time_precision,
    ) {
        Ok(report) => report,
        Err(_) => {
            warn!("Creating report failed!");
            return false;
        }
    };
    match report.get_encoded() {
        Ok(encoded) => {
            out_report.extend(encoded);
            true
        }
        Err(_) => {
            warn!("Encoding report failed!");
            false
        }
    }
}
/// C entry point: builds and encodes a DAP report for a `Prio3Histogram`
/// measurement, where `measurement` is the bucket index and `length` the
/// histogram length.
///
/// Any failure is logged and reported as `false`. On success the encoded
/// report is appended to `out_report` and `true` is returned.
#[no_mangle]
pub extern "C" fn dapGetReportPrioHistogram(
    leader_hpke_config_encoded: &ThinVec<u8>,
    helper_hpke_config_encoded: &ThinVec<u8>,
    measurement: u32,
    task_id: &ThinVec<u8>,
    length: u32,
    time_precision: u64,
    out_report: &mut ThinVec<u8>,
) -> bool {
    let histogram = HistogramMeasurement {
        index: measurement,
        length: length as usize,
    };
    let report = match get_dap_report_internal(
        leader_hpke_config_encoded,
        helper_hpke_config_encoded,
        &histogram,
        task_id,
        time_precision,
    ) {
        Ok(report) => report,
        Err(_) => {
            warn!("Creating report failed!");
            return false;
        }
    };
    match report.get_encoded() {
        Ok(encoded) => {
            out_report.extend(encoded);
            true
        }
        Err(_) => {
            warn!("Encoding report failed!");
            false
        }
    }
}