Skip to content

Commit 8e3e3fc

Browse files
alexclaude
andauthored
Add DHKEM(P-256, HKDF-SHA256) support to HPKE implementation (#14398)
Add P256 KEM variant alongside the existing X25519, following the same abstraction patterns used for KDF and AEAD. KEM-specific operations (key generation, public key serialization/deserialization, DH exchange) are methods on the KEM enum, keeping encap/decap unified with no duplication. Includes key type validation against Python ABCs (TypeError on mismatch), RFC 9180 test vector validation for kem_id=0x0010, and comprehensive wrong-key/wrong-type/wrong-curve tests. https://claude.ai/code/session_01FJG426sLhnaeMytWntj9aJ Co-authored-by: Claude <[email protected]>
1 parent 1809950 commit 8e3e3fc

File tree

5 files changed

+312
-43
lines changed

5 files changed

+312
-43
lines changed

src/cryptography/hazmat/bindings/_rust/openssl/hpke.pyi

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,12 @@
22
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
33
# for complete details.
44

5-
from cryptography.hazmat.primitives.asymmetric import x25519
5+
from cryptography.hazmat.primitives.asymmetric import ec, x25519
66
from cryptography.utils import Buffer
77

88
class KEM:
99
X25519: KEM
10+
P256: KEM
1011

1112
class KDF:
1213
HKDF_SHA256: KDF
@@ -23,27 +24,27 @@ class Suite:
2324
def encrypt(
2425
self,
2526
plaintext: Buffer,
26-
public_key: x25519.X25519PublicKey,
27+
public_key: x25519.X25519PublicKey | ec.EllipticCurvePublicKey,
2728
info: Buffer | None = None,
2829
) -> bytes: ...
2930
def decrypt(
3031
self,
3132
ciphertext: Buffer,
32-
private_key: x25519.X25519PrivateKey,
33+
private_key: x25519.X25519PrivateKey | ec.EllipticCurvePrivateKey,
3334
info: Buffer | None = None,
3435
) -> bytes: ...
3536

3637
def _encrypt_with_aad(
3738
suite: Suite,
3839
plaintext: Buffer,
39-
public_key: x25519.X25519PublicKey,
40+
public_key: x25519.X25519PublicKey | ec.EllipticCurvePublicKey,
4041
info: Buffer | None = None,
4142
aad: Buffer | None = None,
4243
) -> bytes: ...
4344
def _decrypt_with_aad(
4445
suite: Suite,
4546
ciphertext: Buffer,
46-
private_key: x25519.X25519PrivateKey,
47+
private_key: x25519.X25519PrivateKey | ec.EllipticCurvePrivateKey,
4748
info: Buffer | None = None,
4849
aad: Buffer | None = None,
4950
) -> bytes: ...

src/rust/src/backend/ec.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ pub(crate) fn public_key_from_pkey(
128128

129129
#[pyo3::pyfunction]
130130
#[pyo3(signature = (curve, backend=None))]
131-
fn generate_private_key(
131+
pub(crate) fn generate_private_key(
132132
py: pyo3::Python<'_>,
133133
curve: pyo3::Bound<'_, pyo3::PyAny>,
134134
backend: Option<pyo3::Bound<'_, pyo3::PyAny>>,
@@ -171,7 +171,7 @@ fn derive_private_key(
171171
}
172172

173173
#[pyo3::pyfunction]
174-
fn from_public_bytes(
174+
pub(crate) fn from_public_bytes(
175175
py: pyo3::Python<'_>,
176176
py_curve: pyo3::Bound<'_, pyo3::PyAny>,
177177
data: &[u8],

src/rust/src/backend/hpke.rs

Lines changed: 195 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
use pyo3::types::{PyAnyMethods, PyBytesMethods};
66

77
use crate::backend::aead::{AesGcm, ChaCha20Poly1305};
8+
use crate::backend::ec;
89
use crate::backend::kdf::{hkdf_extract, HkdfExpand};
910
use crate::backend::x25519;
1011
use crate::buf::CffiBuf;
@@ -18,6 +19,10 @@ mod kem_params {
1819
pub const X25519_ID: u16 = 0x0020;
1920
pub const X25519_NSECRET: usize = 32;
2021
pub const X25519_NENC: usize = 32;
22+
23+
pub const P256_ID: u16 = 0x0010;
24+
pub const P256_NSECRET: usize = 32;
25+
pub const P256_NENC: usize = 65;
2126
}
2227

2328
mod kdf_params {
@@ -54,6 +59,170 @@ mod aead_params {
5459
#[derive(Clone, PartialEq, Eq, Hash)]
5560
pub(crate) enum KEM {
5661
X25519,
62+
P256,
63+
}
64+
65+
impl KEM {
66+
fn id(&self) -> u16 {
67+
match self {
68+
KEM::X25519 => kem_params::X25519_ID,
69+
KEM::P256 => kem_params::P256_ID,
70+
}
71+
}
72+
73+
fn secret_length(&self) -> usize {
74+
match self {
75+
KEM::X25519 => kem_params::X25519_NSECRET,
76+
KEM::P256 => kem_params::P256_NSECRET,
77+
}
78+
}
79+
80+
fn enc_length(&self) -> usize {
81+
match self {
82+
KEM::X25519 => kem_params::X25519_NENC,
83+
KEM::P256 => kem_params::P256_NENC,
84+
}
85+
}
86+
87+
fn check_public_key(
88+
&self,
89+
py: pyo3::Python<'_>,
90+
key: &pyo3::Bound<'_, pyo3::PyAny>,
91+
) -> CryptographyResult<()> {
92+
match self {
93+
KEM::X25519 => {
94+
if !key.is_instance(&types::X25519_PUBLIC_KEY.get(py)?)? {
95+
return Err(CryptographyError::from(
96+
pyo3::exceptions::PyTypeError::new_err(
97+
"Expected X25519PublicKey for KEM.X25519",
98+
),
99+
));
100+
}
101+
}
102+
KEM::P256 => {
103+
if !key.is_instance(&types::ELLIPTIC_CURVE_PUBLIC_KEY.get(py)?)? {
104+
return Err(CryptographyError::from(
105+
pyo3::exceptions::PyTypeError::new_err(
106+
"Expected EllipticCurvePublicKey for KEM.P256",
107+
),
108+
));
109+
}
110+
let curve = key.getattr(pyo3::intern!(py, "curve"))?;
111+
if !curve.is_instance(&types::SECP256R1.get(py)?)? {
112+
return Err(CryptographyError::from(
113+
pyo3::exceptions::PyTypeError::new_err(
114+
"Expected EllipticCurvePublicKey on secp256r1 for KEM.P256",
115+
),
116+
));
117+
}
118+
}
119+
}
120+
Ok(())
121+
}
122+
123+
fn check_private_key(
124+
&self,
125+
py: pyo3::Python<'_>,
126+
key: &pyo3::Bound<'_, pyo3::PyAny>,
127+
) -> CryptographyResult<()> {
128+
match self {
129+
KEM::X25519 => {
130+
if !key.is_instance(&types::X25519_PRIVATE_KEY.get(py)?)? {
131+
return Err(CryptographyError::from(
132+
pyo3::exceptions::PyTypeError::new_err(
133+
"Expected X25519PrivateKey for KEM.X25519",
134+
),
135+
));
136+
}
137+
}
138+
KEM::P256 => {
139+
if !key.is_instance(&types::ELLIPTIC_CURVE_PRIVATE_KEY.get(py)?)? {
140+
return Err(CryptographyError::from(
141+
pyo3::exceptions::PyTypeError::new_err(
142+
"Expected EllipticCurvePrivateKey for KEM.P256",
143+
),
144+
));
145+
}
146+
let curve = key.getattr(pyo3::intern!(py, "curve"))?;
147+
if !curve.is_instance(&types::SECP256R1.get(py)?)? {
148+
return Err(CryptographyError::from(
149+
pyo3::exceptions::PyTypeError::new_err(
150+
"Expected EllipticCurvePrivateKey on secp256r1 for KEM.P256",
151+
),
152+
));
153+
}
154+
}
155+
}
156+
Ok(())
157+
}
158+
159+
fn generate_key<'p>(
160+
&self,
161+
py: pyo3::Python<'p>,
162+
) -> CryptographyResult<pyo3::Bound<'p, pyo3::PyAny>> {
163+
match self {
164+
KEM::X25519 => Ok(pyo3::Bound::new(py, x25519::generate_key()?)?.into_any()),
165+
KEM::P256 => {
166+
let secp256r1 = types::SECP256R1.get(py)?.call0()?;
167+
Ok(
168+
pyo3::Bound::new(py, ec::generate_private_key(py, secp256r1, None)?)?
169+
.into_any(),
170+
)
171+
}
172+
}
173+
}
174+
175+
fn serialize_public_key<'p>(
176+
&self,
177+
py: pyo3::Python<'p>,
178+
pk: &pyo3::Bound<'p, pyo3::PyAny>,
179+
) -> CryptographyResult<pyo3::Bound<'p, pyo3::types::PyBytes>> {
180+
match self {
181+
KEM::X25519 => Ok(pk
182+
.call_method0(pyo3::intern!(py, "public_bytes_raw"))?
183+
.extract()?),
184+
KEM::P256 => Ok(pk
185+
.call_method1(
186+
pyo3::intern!(py, "public_bytes"),
187+
(
188+
crate::serialization::Encoding::X962,
189+
crate::serialization::PublicFormat::UncompressedPoint,
190+
),
191+
)?
192+
.extract()?),
193+
}
194+
}
195+
196+
fn deserialize_public_key<'p>(
197+
&self,
198+
py: pyo3::Python<'p>,
199+
data: &[u8],
200+
) -> CryptographyResult<pyo3::Bound<'p, pyo3::PyAny>> {
201+
match self {
202+
KEM::X25519 => Ok(pyo3::Bound::new(py, x25519::from_public_bytes(data)?)?.into_any()),
203+
KEM::P256 => {
204+
let secp256r1 = types::SECP256R1.get(py)?.call0()?;
205+
Ok(pyo3::Bound::new(py, ec::from_public_bytes(py, secp256r1, data)?)?.into_any())
206+
}
207+
}
208+
}
209+
210+
fn exchange<'p>(
211+
&self,
212+
py: pyo3::Python<'p>,
213+
private_key: &pyo3::Bound<'p, pyo3::PyAny>,
214+
public_key: &pyo3::Bound<'p, pyo3::PyAny>,
215+
) -> CryptographyResult<pyo3::Bound<'p, pyo3::PyAny>> {
216+
match self {
217+
KEM::X25519 => {
218+
Ok(private_key.call_method1(pyo3::intern!(py, "exchange"), (public_key,))?)
219+
}
220+
KEM::P256 => {
221+
let ecdh = types::ECDH.get(py)?.call0()?;
222+
Ok(private_key.call_method1(pyo3::intern!(py, "exchange"), (&ecdh, public_key))?)
223+
}
224+
}
225+
}
57226
}
58227

59228
#[allow(clippy::upper_case_acronyms)]
@@ -146,6 +315,7 @@ impl AEAD {
146315
#[pyo3::pyclass(frozen, module = "cryptography.hazmat.bindings._rust.openssl.hpke")]
147316
pub(crate) struct Suite {
148317
aead: AEAD,
318+
kem: KEM,
149319
kem_suite_id: [u8; 5],
150320
hpke_suite_id: [u8; 10],
151321
kdf: KDF,
@@ -218,7 +388,7 @@ impl Suite {
218388
&eae_prk,
219389
b"shared_secret",
220390
kem_context,
221-
kem_params::X25519_NSECRET,
391+
self.kem.secret_length(),
222392
)
223393
}
224394

@@ -230,22 +400,19 @@ impl Suite {
230400
pyo3::Bound<'p, pyo3::types::PyBytes>,
231401
pyo3::Bound<'p, pyo3::types::PyBytes>,
232402
)> {
233-
let sk_e = pyo3::Bound::new(py, x25519::generate_key()?)?;
403+
let sk_e = self.kem.generate_key(py)?;
234404
let pk_e = sk_e.call_method0(pyo3::intern!(py, "public_key"))?;
405+
let pk_e_bytes = self.kem.serialize_public_key(py, &pk_e)?;
406+
let pk_r_bytes = self.kem.serialize_public_key(py, pk_r)?;
235407

236-
let pk_e_bytes: pyo3::Bound<'p, pyo3::types::PyBytes> = pk_e
237-
.call_method0(pyo3::intern!(py, "public_bytes_raw"))?
238-
.extract()?;
239-
240-
let pk_r_bytes = pk_r.call_method0(pyo3::intern!(py, "public_bytes_raw"))?;
241-
let pk_r_raw = pk_r_bytes.extract::<&[u8]>()?;
242-
243-
let dh_result = sk_e.call_method1(pyo3::intern!(py, "exchange"), (pk_r,))?;
408+
let dh_result = self.kem.exchange(py, &sk_e, pk_r)?;
244409
let dh = dh_result.extract::<&[u8]>()?;
245410

246-
let mut kem_context = [0u8; 64];
247-
kem_context[..32].copy_from_slice(pk_e_bytes.as_bytes());
248-
kem_context[32..].copy_from_slice(pk_r_raw);
411+
let mut kem_context =
412+
Vec::with_capacity(pk_e_bytes.as_bytes().len() + pk_r_bytes.as_bytes().len());
413+
kem_context.extend_from_slice(pk_e_bytes.as_bytes());
414+
kem_context.extend_from_slice(pk_r_bytes.as_bytes());
415+
249416
let shared_secret = self.extract_and_expand(py, dh, &kem_context)?;
250417
Ok((shared_secret, pk_e_bytes))
251418
}
@@ -256,19 +423,18 @@ impl Suite {
256423
enc: &[u8],
257424
sk_r: &pyo3::Bound<'_, pyo3::PyAny>,
258425
) -> CryptographyResult<pyo3::Bound<'p, pyo3::types::PyBytes>> {
259-
// Reconstruct pk_e from enc
260-
let pk_e = pyo3::Bound::new(py, x25519::from_public_bytes(enc)?)?;
426+
let pk_e = self.kem.deserialize_public_key(py, enc)?;
261427

262-
let dh_result = sk_r.call_method1(pyo3::intern!(py, "exchange"), (&pk_e,))?;
428+
let dh_result = self.kem.exchange(py, sk_r, &pk_e)?;
263429
let dh = dh_result.extract::<&[u8]>()?;
264430

265431
let pk_rm = sk_r.call_method0(pyo3::intern!(py, "public_key"))?;
266-
let pk_rm_bytes = pk_rm.call_method0(pyo3::intern!(py, "public_bytes_raw"))?;
267-
let pk_rm_raw = pk_rm_bytes.extract::<&[u8]>()?;
432+
let pk_rm_bytes = self.kem.serialize_public_key(py, &pk_rm)?;
433+
434+
let mut kem_context = Vec::with_capacity(enc.len() + pk_rm_bytes.as_bytes().len());
435+
kem_context.extend_from_slice(enc);
436+
kem_context.extend_from_slice(pk_rm_bytes.as_bytes());
268437

269-
let mut kem_context = [0u8; 64];
270-
kem_context[..32].copy_from_slice(enc);
271-
kem_context[32..].copy_from_slice(pk_rm_raw);
272438
self.extract_and_expand(py, dh, &kem_context)
273439
}
274440

@@ -361,6 +527,7 @@ impl Suite {
361527
info: Option<CffiBuf<'_>>,
362528
aad: Option<CffiBuf<'_>>,
363529
) -> CryptographyResult<pyo3::Bound<'p, pyo3::types::PyBytes>> {
530+
self.kem.check_public_key(py, public_key)?;
364531
let info_bytes: &[u8] = info.as_ref().map(|b| b.as_bytes()).unwrap_or(b"");
365532

366533
let (shared_secret, enc) = self.encap(py, public_key)?;
@@ -389,14 +556,15 @@ impl Suite {
389556
info: Option<CffiBuf<'_>>,
390557
aad: Option<CffiBuf<'_>>,
391558
) -> CryptographyResult<pyo3::Bound<'p, pyo3::types::PyBytes>> {
559+
self.kem.check_private_key(py, private_key)?;
392560
let ct_bytes = ciphertext.as_bytes();
393-
if ct_bytes.len() < kem_params::X25519_NENC + self.aead.tag_length() {
561+
if ct_bytes.len() < self.kem.enc_length() + self.aead.tag_length() {
394562
return Err(CryptographyError::from(exceptions::InvalidTag::new_err(())));
395563
}
396564

397565
let info_bytes: &[u8] = info.as_ref().map(|b| b.as_bytes()).unwrap_or(b"");
398566

399-
let (enc, ct) = ct_bytes.split_at(kem_params::X25519_NENC);
567+
let (enc, ct) = ct_bytes.split_at(self.kem.enc_length());
400568

401569
let shared_secret = self
402570
.decap(py, enc, private_key)
@@ -445,20 +613,21 @@ impl Suite {
445613
#[pyo3::pymethods]
446614
impl Suite {
447615
#[new]
448-
fn new(_kem: KEM, kdf: KDF, aead: AEAD) -> CryptographyResult<Suite> {
616+
fn new(kem: KEM, kdf: KDF, aead: AEAD) -> CryptographyResult<Suite> {
449617
// Build suite IDs
450618
let mut kem_suite_id = [0u8; 5];
451619
kem_suite_id[..3].copy_from_slice(b"KEM");
452-
kem_suite_id[3..].copy_from_slice(&kem_params::X25519_ID.to_be_bytes());
620+
kem_suite_id[3..].copy_from_slice(&kem.id().to_be_bytes());
453621

454622
let mut hpke_suite_id = [0u8; 10];
455623
hpke_suite_id[..4].copy_from_slice(b"HPKE");
456-
hpke_suite_id[4..6].copy_from_slice(&kem_params::X25519_ID.to_be_bytes());
624+
hpke_suite_id[4..6].copy_from_slice(&kem.id().to_be_bytes());
457625
hpke_suite_id[6..8].copy_from_slice(&kdf.id().to_be_bytes());
458626
hpke_suite_id[8..10].copy_from_slice(&aead.id().to_be_bytes());
459627

460628
Ok(Suite {
461629
aead,
630+
kem,
462631
kem_suite_id,
463632
hpke_suite_id,
464633
kdf,

0 commit comments

Comments
 (0)