55use pyo3:: types:: { PyAnyMethods , PyBytesMethods } ;
66
77use crate :: backend:: aead:: { AesGcm , ChaCha20Poly1305 } ;
8+ use crate :: backend:: ec;
89use crate :: backend:: kdf:: { hkdf_extract, HkdfExpand } ;
910use crate :: backend:: x25519;
1011use crate :: buf:: CffiBuf ;
@@ -18,6 +19,10 @@ mod kem_params {
1819 pub const X25519_ID : u16 = 0x0020 ;
1920 pub const X25519_NSECRET : usize = 32 ;
2021 pub const X25519_NENC : usize = 32 ;
22+
23+ pub const P256_ID : u16 = 0x0010 ;
24+ pub const P256_NSECRET : usize = 32 ;
25+ pub const P256_NENC : usize = 65 ;
2126}
2227
2328mod kdf_params {
@@ -54,6 +59,170 @@ mod aead_params {
5459#[ derive( Clone , PartialEq , Eq , Hash ) ]
5560pub ( crate ) enum KEM {
5661 X25519 ,
62+ P256 ,
63+ }
64+
65+ impl KEM {
66+ fn id ( & self ) -> u16 {
67+ match self {
68+ KEM :: X25519 => kem_params:: X25519_ID ,
69+ KEM :: P256 => kem_params:: P256_ID ,
70+ }
71+ }
72+
73+ fn secret_length ( & self ) -> usize {
74+ match self {
75+ KEM :: X25519 => kem_params:: X25519_NSECRET ,
76+ KEM :: P256 => kem_params:: P256_NSECRET ,
77+ }
78+ }
79+
80+ fn enc_length ( & self ) -> usize {
81+ match self {
82+ KEM :: X25519 => kem_params:: X25519_NENC ,
83+ KEM :: P256 => kem_params:: P256_NENC ,
84+ }
85+ }
86+
87+ fn check_public_key (
88+ & self ,
89+ py : pyo3:: Python < ' _ > ,
90+ key : & pyo3:: Bound < ' _ , pyo3:: PyAny > ,
91+ ) -> CryptographyResult < ( ) > {
92+ match self {
93+ KEM :: X25519 => {
94+ if !key. is_instance ( & types:: X25519_PUBLIC_KEY . get ( py) ?) ? {
95+ return Err ( CryptographyError :: from (
96+ pyo3:: exceptions:: PyTypeError :: new_err (
97+ "Expected X25519PublicKey for KEM.X25519" ,
98+ ) ,
99+ ) ) ;
100+ }
101+ }
102+ KEM :: P256 => {
103+ if !key. is_instance ( & types:: ELLIPTIC_CURVE_PUBLIC_KEY . get ( py) ?) ? {
104+ return Err ( CryptographyError :: from (
105+ pyo3:: exceptions:: PyTypeError :: new_err (
106+ "Expected EllipticCurvePublicKey for KEM.P256" ,
107+ ) ,
108+ ) ) ;
109+ }
110+ let curve = key. getattr ( pyo3:: intern!( py, "curve" ) ) ?;
111+ if !curve. is_instance ( & types:: SECP256R1 . get ( py) ?) ? {
112+ return Err ( CryptographyError :: from (
113+ pyo3:: exceptions:: PyTypeError :: new_err (
114+ "Expected EllipticCurvePublicKey on secp256r1 for KEM.P256" ,
115+ ) ,
116+ ) ) ;
117+ }
118+ }
119+ }
120+ Ok ( ( ) )
121+ }
122+
123+ fn check_private_key (
124+ & self ,
125+ py : pyo3:: Python < ' _ > ,
126+ key : & pyo3:: Bound < ' _ , pyo3:: PyAny > ,
127+ ) -> CryptographyResult < ( ) > {
128+ match self {
129+ KEM :: X25519 => {
130+ if !key. is_instance ( & types:: X25519_PRIVATE_KEY . get ( py) ?) ? {
131+ return Err ( CryptographyError :: from (
132+ pyo3:: exceptions:: PyTypeError :: new_err (
133+ "Expected X25519PrivateKey for KEM.X25519" ,
134+ ) ,
135+ ) ) ;
136+ }
137+ }
138+ KEM :: P256 => {
139+ if !key. is_instance ( & types:: ELLIPTIC_CURVE_PRIVATE_KEY . get ( py) ?) ? {
140+ return Err ( CryptographyError :: from (
141+ pyo3:: exceptions:: PyTypeError :: new_err (
142+ "Expected EllipticCurvePrivateKey for KEM.P256" ,
143+ ) ,
144+ ) ) ;
145+ }
146+ let curve = key. getattr ( pyo3:: intern!( py, "curve" ) ) ?;
147+ if !curve. is_instance ( & types:: SECP256R1 . get ( py) ?) ? {
148+ return Err ( CryptographyError :: from (
149+ pyo3:: exceptions:: PyTypeError :: new_err (
150+ "Expected EllipticCurvePrivateKey on secp256r1 for KEM.P256" ,
151+ ) ,
152+ ) ) ;
153+ }
154+ }
155+ }
156+ Ok ( ( ) )
157+ }
158+
159+ fn generate_key < ' p > (
160+ & self ,
161+ py : pyo3:: Python < ' p > ,
162+ ) -> CryptographyResult < pyo3:: Bound < ' p , pyo3:: PyAny > > {
163+ match self {
164+ KEM :: X25519 => Ok ( pyo3:: Bound :: new ( py, x25519:: generate_key ( ) ?) ?. into_any ( ) ) ,
165+ KEM :: P256 => {
166+ let secp256r1 = types:: SECP256R1 . get ( py) ?. call0 ( ) ?;
167+ Ok (
168+ pyo3:: Bound :: new ( py, ec:: generate_private_key ( py, secp256r1, None ) ?) ?
169+ . into_any ( ) ,
170+ )
171+ }
172+ }
173+ }
174+
175+ fn serialize_public_key < ' p > (
176+ & self ,
177+ py : pyo3:: Python < ' p > ,
178+ pk : & pyo3:: Bound < ' p , pyo3:: PyAny > ,
179+ ) -> CryptographyResult < pyo3:: Bound < ' p , pyo3:: types:: PyBytes > > {
180+ match self {
181+ KEM :: X25519 => Ok ( pk
182+ . call_method0 ( pyo3:: intern!( py, "public_bytes_raw" ) ) ?
183+ . extract ( ) ?) ,
184+ KEM :: P256 => Ok ( pk
185+ . call_method1 (
186+ pyo3:: intern!( py, "public_bytes" ) ,
187+ (
188+ crate :: serialization:: Encoding :: X962 ,
189+ crate :: serialization:: PublicFormat :: UncompressedPoint ,
190+ ) ,
191+ ) ?
192+ . extract ( ) ?) ,
193+ }
194+ }
195+
196+ fn deserialize_public_key < ' p > (
197+ & self ,
198+ py : pyo3:: Python < ' p > ,
199+ data : & [ u8 ] ,
200+ ) -> CryptographyResult < pyo3:: Bound < ' p , pyo3:: PyAny > > {
201+ match self {
202+ KEM :: X25519 => Ok ( pyo3:: Bound :: new ( py, x25519:: from_public_bytes ( data) ?) ?. into_any ( ) ) ,
203+ KEM :: P256 => {
204+ let secp256r1 = types:: SECP256R1 . get ( py) ?. call0 ( ) ?;
205+ Ok ( pyo3:: Bound :: new ( py, ec:: from_public_bytes ( py, secp256r1, data) ?) ?. into_any ( ) )
206+ }
207+ }
208+ }
209+
210+ fn exchange < ' p > (
211+ & self ,
212+ py : pyo3:: Python < ' p > ,
213+ private_key : & pyo3:: Bound < ' p , pyo3:: PyAny > ,
214+ public_key : & pyo3:: Bound < ' p , pyo3:: PyAny > ,
215+ ) -> CryptographyResult < pyo3:: Bound < ' p , pyo3:: PyAny > > {
216+ match self {
217+ KEM :: X25519 => {
218+ Ok ( private_key. call_method1 ( pyo3:: intern!( py, "exchange" ) , ( public_key, ) ) ?)
219+ }
220+ KEM :: P256 => {
221+ let ecdh = types:: ECDH . get ( py) ?. call0 ( ) ?;
222+ Ok ( private_key. call_method1 ( pyo3:: intern!( py, "exchange" ) , ( & ecdh, public_key) ) ?)
223+ }
224+ }
225+ }
57226}
58227
59228#[ allow( clippy:: upper_case_acronyms) ]
@@ -146,6 +315,7 @@ impl AEAD {
146315#[ pyo3:: pyclass( frozen, module = "cryptography.hazmat.bindings._rust.openssl.hpke" ) ]
147316pub ( crate ) struct Suite {
148317 aead : AEAD ,
318+ kem : KEM ,
149319 kem_suite_id : [ u8 ; 5 ] ,
150320 hpke_suite_id : [ u8 ; 10 ] ,
151321 kdf : KDF ,
@@ -218,7 +388,7 @@ impl Suite {
218388 & eae_prk,
219389 b"shared_secret" ,
220390 kem_context,
221- kem_params :: X25519_NSECRET ,
391+ self . kem . secret_length ( ) ,
222392 )
223393 }
224394
@@ -230,22 +400,19 @@ impl Suite {
230400 pyo3:: Bound < ' p , pyo3:: types:: PyBytes > ,
231401 pyo3:: Bound < ' p , pyo3:: types:: PyBytes > ,
232402 ) > {
233- let sk_e = pyo3 :: Bound :: new ( py , x25519 :: generate_key ( ) ? ) ?;
403+ let sk_e = self . kem . generate_key ( py ) ?;
234404 let pk_e = sk_e. call_method0 ( pyo3:: intern!( py, "public_key" ) ) ?;
405+ let pk_e_bytes = self . kem . serialize_public_key ( py, & pk_e) ?;
406+ let pk_r_bytes = self . kem . serialize_public_key ( py, pk_r) ?;
235407
236- let pk_e_bytes: pyo3:: Bound < ' p , pyo3:: types:: PyBytes > = pk_e
237- . call_method0 ( pyo3:: intern!( py, "public_bytes_raw" ) ) ?
238- . extract ( ) ?;
239-
240- let pk_r_bytes = pk_r. call_method0 ( pyo3:: intern!( py, "public_bytes_raw" ) ) ?;
241- let pk_r_raw = pk_r_bytes. extract :: < & [ u8 ] > ( ) ?;
242-
243- let dh_result = sk_e. call_method1 ( pyo3:: intern!( py, "exchange" ) , ( pk_r, ) ) ?;
408+ let dh_result = self . kem . exchange ( py, & sk_e, pk_r) ?;
244409 let dh = dh_result. extract :: < & [ u8 ] > ( ) ?;
245410
246- let mut kem_context = [ 0u8 ; 64 ] ;
247- kem_context[ ..32 ] . copy_from_slice ( pk_e_bytes. as_bytes ( ) ) ;
248- kem_context[ 32 ..] . copy_from_slice ( pk_r_raw) ;
411+ let mut kem_context =
412+ Vec :: with_capacity ( pk_e_bytes. as_bytes ( ) . len ( ) + pk_r_bytes. as_bytes ( ) . len ( ) ) ;
413+ kem_context. extend_from_slice ( pk_e_bytes. as_bytes ( ) ) ;
414+ kem_context. extend_from_slice ( pk_r_bytes. as_bytes ( ) ) ;
415+
249416 let shared_secret = self . extract_and_expand ( py, dh, & kem_context) ?;
250417 Ok ( ( shared_secret, pk_e_bytes) )
251418 }
@@ -256,19 +423,18 @@ impl Suite {
256423 enc : & [ u8 ] ,
257424 sk_r : & pyo3:: Bound < ' _ , pyo3:: PyAny > ,
258425 ) -> CryptographyResult < pyo3:: Bound < ' p , pyo3:: types:: PyBytes > > {
259- // Reconstruct pk_e from enc
260- let pk_e = pyo3:: Bound :: new ( py, x25519:: from_public_bytes ( enc) ?) ?;
426+ let pk_e = self . kem . deserialize_public_key ( py, enc) ?;
261427
262- let dh_result = sk_r . call_method1 ( pyo3 :: intern! ( py, "exchange" ) , ( & pk_e, ) ) ?;
428+ let dh_result = self . kem . exchange ( py, sk_r , & pk_e) ?;
263429 let dh = dh_result. extract :: < & [ u8 ] > ( ) ?;
264430
265431 let pk_rm = sk_r. call_method0 ( pyo3:: intern!( py, "public_key" ) ) ?;
266- let pk_rm_bytes = pk_rm. call_method0 ( pyo3:: intern!( py, "public_bytes_raw" ) ) ?;
267- let pk_rm_raw = pk_rm_bytes. extract :: < & [ u8 ] > ( ) ?;
432+ let pk_rm_bytes = self . kem . serialize_public_key ( py, & pk_rm) ?;
433+
434+ let mut kem_context = Vec :: with_capacity ( enc. len ( ) + pk_rm_bytes. as_bytes ( ) . len ( ) ) ;
435+ kem_context. extend_from_slice ( enc) ;
436+ kem_context. extend_from_slice ( pk_rm_bytes. as_bytes ( ) ) ;
268437
269- let mut kem_context = [ 0u8 ; 64 ] ;
270- kem_context[ ..32 ] . copy_from_slice ( enc) ;
271- kem_context[ 32 ..] . copy_from_slice ( pk_rm_raw) ;
272438 self . extract_and_expand ( py, dh, & kem_context)
273439 }
274440
@@ -361,6 +527,7 @@ impl Suite {
361527 info : Option < CffiBuf < ' _ > > ,
362528 aad : Option < CffiBuf < ' _ > > ,
363529 ) -> CryptographyResult < pyo3:: Bound < ' p , pyo3:: types:: PyBytes > > {
530+ self . kem . check_public_key ( py, public_key) ?;
364531 let info_bytes: & [ u8 ] = info. as_ref ( ) . map ( |b| b. as_bytes ( ) ) . unwrap_or ( b"" ) ;
365532
366533 let ( shared_secret, enc) = self . encap ( py, public_key) ?;
@@ -389,14 +556,15 @@ impl Suite {
389556 info : Option < CffiBuf < ' _ > > ,
390557 aad : Option < CffiBuf < ' _ > > ,
391558 ) -> CryptographyResult < pyo3:: Bound < ' p , pyo3:: types:: PyBytes > > {
559+ self . kem . check_private_key ( py, private_key) ?;
392560 let ct_bytes = ciphertext. as_bytes ( ) ;
393- if ct_bytes. len ( ) < kem_params :: X25519_NENC + self . aead . tag_length ( ) {
561+ if ct_bytes. len ( ) < self . kem . enc_length ( ) + self . aead . tag_length ( ) {
394562 return Err ( CryptographyError :: from ( exceptions:: InvalidTag :: new_err ( ( ) ) ) ) ;
395563 }
396564
397565 let info_bytes: & [ u8 ] = info. as_ref ( ) . map ( |b| b. as_bytes ( ) ) . unwrap_or ( b"" ) ;
398566
399- let ( enc, ct) = ct_bytes. split_at ( kem_params :: X25519_NENC ) ;
567+ let ( enc, ct) = ct_bytes. split_at ( self . kem . enc_length ( ) ) ;
400568
401569 let shared_secret = self
402570 . decap ( py, enc, private_key)
@@ -445,20 +613,21 @@ impl Suite {
445613#[ pyo3:: pymethods]
446614impl Suite {
447615 #[ new]
448- fn new ( _kem : KEM , kdf : KDF , aead : AEAD ) -> CryptographyResult < Suite > {
616+ fn new ( kem : KEM , kdf : KDF , aead : AEAD ) -> CryptographyResult < Suite > {
449617 // Build suite IDs
450618 let mut kem_suite_id = [ 0u8 ; 5 ] ;
451619 kem_suite_id[ ..3 ] . copy_from_slice ( b"KEM" ) ;
452- kem_suite_id[ 3 ..] . copy_from_slice ( & kem_params :: X25519_ID . to_be_bytes ( ) ) ;
620+ kem_suite_id[ 3 ..] . copy_from_slice ( & kem . id ( ) . to_be_bytes ( ) ) ;
453621
454622 let mut hpke_suite_id = [ 0u8 ; 10 ] ;
455623 hpke_suite_id[ ..4 ] . copy_from_slice ( b"HPKE" ) ;
456- hpke_suite_id[ 4 ..6 ] . copy_from_slice ( & kem_params :: X25519_ID . to_be_bytes ( ) ) ;
624+ hpke_suite_id[ 4 ..6 ] . copy_from_slice ( & kem . id ( ) . to_be_bytes ( ) ) ;
457625 hpke_suite_id[ 6 ..8 ] . copy_from_slice ( & kdf. id ( ) . to_be_bytes ( ) ) ;
458626 hpke_suite_id[ 8 ..10 ] . copy_from_slice ( & aead. id ( ) . to_be_bytes ( ) ) ;
459627
460628 Ok ( Suite {
461629 aead,
630+ kem,
462631 kem_suite_id,
463632 hpke_suite_id,
464633 kdf,
0 commit comments