Description
This relates to reqwless#127 on reqwless, specifically the alternative approach.
I suggest adding a TLS connector to embedded-tls: a type that implements `TcpConnect` (and possibly other traits in the future).
It acts like a TCP connector, but returns a TLS connection (some kind of `TlsConnection`) instead of a plain TCP one.
The goal is to be able to use this inside, e.g., a `reqwless::HttpClient`, which decouples the `TlsConfig` from reqwless itself.
This would also address reqwless#123, since that should not really be a reqwless concern.
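A synchronous, std-only sketch of the wrapping idea, under stated assumptions: `Connect` and `Stream` are hypothetical stand-ins for `TcpConnect` and `Read + Write` (the real embedded-nal-async and embedded-io-async traits are async), and `TlsLike` only tags the payload instead of running a handshake. It only shows the shape: the wrapper connector owns the TLS settings, delegates the transport to an inner connector, and returns a wrapped connection, so a client that is generic over the connector never sees any TLS configuration.

```rust
// HYPOTHETICAL stand-ins: not the embedded-nal-async / embedded-io-async APIs.

/// Hypothetical byte stream: `send` returns what actually goes on the wire.
trait Stream {
    fn send(&mut self, msg: &str) -> String;
}

/// Hypothetical connector, shaped like `TcpConnect`: `&self` receiver and a
/// connection type that may borrow from the connector.
trait Connect {
    type Connection<'a>: Stream
    where
        Self: 'a;
    fn connect(&self, host: &str) -> Result<Self::Connection<'_>, String>;
}

/// Plain "TCP" connection: sends bytes unchanged.
struct PlainConn;
impl Stream for PlainConn {
    fn send(&mut self, msg: &str) -> String {
        msg.to_string()
    }
}

struct PlainConnector;
impl Connect for PlainConnector {
    type Connection<'a> = PlainConn where Self: 'a;
    fn connect(&self, _host: &str) -> Result<Self::Connection<'_>, String> {
        Ok(PlainConn)
    }
}

/// Hypothetical "TLS" layer over any stream. A real one would perform a
/// handshake and encrypt records; this one only tags the payload.
struct TlsLike<S> {
    inner: S,
    server_name: String,
}
impl<S: Stream> Stream for TlsLike<S> {
    fn send(&mut self, msg: &str) -> String {
        self.inner.send(&format!("tls[{}]({})", self.server_name, msg))
    }
}

/// The proposed connector in miniature: owns the TLS settings, delegates the
/// transport to `C`, and hands back the wrapped connection.
struct TlsLikeConnector<C> {
    inner: C,
    server_name: String,
}
impl<C: Connect> Connect for TlsLikeConnector<C> {
    type Connection<'a> = TlsLike<C::Connection<'a>> where Self: 'a;
    fn connect(&self, host: &str) -> Result<Self::Connection<'_>, String> {
        let inner = self.inner.connect(host)?;
        Ok(TlsLike { inner, server_name: self.server_name.clone() })
    }
}

/// A client generic over the connector: it never sees TLS settings.
struct Client<C> {
    connector: C,
}
impl<C: Connect> Client<C> {
    fn get(&self, host: &str) -> Result<String, String> {
        let mut conn = self.connector.connect(host)?;
        Ok(conn.send("GET /"))
    }
}

fn main() {
    let plain = Client { connector: PlainConnector };
    let tls = Client {
        connector: TlsLikeConnector { inner: PlainConnector, server_name: "example.com".into() },
    };
    println!("{}", plain.get("example.com").unwrap()); // GET /
    println!("{}", tls.get("example.com").unwrap()); // tls[example.com](GET /)
}
```

Swapping `PlainConnector` for `TlsLikeConnector` changes nothing in `Client`, which is the decoupling the proposal is after.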
I've taken a shot at implementing this for demonstration purposes, and I'm looking for feedback on whether it's a good idea.
It's in rough shape, and the ergonomics might not be great. A TLS connector has to hold the TLS record buffers so it can create the `TlsConnection`, and it must use interior mutability to hand them out, because `TcpConnect::connect` takes `&self` while a `TlsConnection` obviously needs the record buffers to be mutable.
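To isolate that `&self` constraint, here is a std-only sketch under stated assumptions: `BufferPool` and `Lease` are hypothetical names, and the `RefCell` makes this single-threaded, whereas heapless's `BoxPool` is a static, lock-free pool. It shows that `lease` takes `&self`, like `TcpConnect::connect`, so handing out a mutable buffer needs interior mutability, and that the buffer returns to the pool when the lease is dropped, much like the heapless `Box` in the implementation.

```rust
use std::cell::RefCell;
use std::ops::{Deref, DerefMut};

/// HYPOTHETICAL fixed-size buffer pool. `lease` only gets `&self`, so the free
/// list sits behind a `RefCell` (single-threaded interior mutability).
struct BufferPool<const N: usize> {
    free: RefCell<Vec<Box<[u8; N]>>>,
}

impl<const N: usize> BufferPool<N> {
    fn new(count: usize) -> Self {
        let free: Vec<Box<[u8; N]>> = (0..count).map(|_| Box::new([0u8; N])).collect();
        BufferPool { free: RefCell::new(free) }
    }

    /// Hands out exclusive access to one buffer, or `None` if the pool is empty.
    fn lease(&self) -> Option<Lease<'_, N>> {
        let buf = self.free.borrow_mut().pop()?;
        Some(Lease { pool: self, buf: Some(buf) })
    }

    fn available(&self) -> usize {
        self.free.borrow().len()
    }
}

/// A leased buffer; dropping it pushes the buffer back onto the free list.
struct Lease<'p, const N: usize> {
    pool: &'p BufferPool<N>,
    buf: Option<Box<[u8; N]>>,
}

impl<const N: usize> Deref for Lease<'_, N> {
    type Target = [u8; N];
    fn deref(&self) -> &[u8; N] {
        self.buf.as_ref().unwrap()
    }
}

impl<const N: usize> DerefMut for Lease<'_, N> {
    fn deref_mut(&mut self) -> &mut [u8; N] {
        self.buf.as_mut().unwrap()
    }
}

impl<const N: usize> Drop for Lease<'_, N> {
    fn drop(&mut self) {
        if let Some(buf) = self.buf.take() {
            self.pool.free.borrow_mut().push(buf);
        }
    }
}

fn main() {
    let pool: BufferPool<16> = BufferPool::new(1);
    {
        let mut rx = pool.lease().expect("one buffer available");
        rx[0] = 0x16; // e.g. the TLS "handshake" record content type
        assert!(pool.lease().is_none()); // exhausted while the lease is alive
    } // `rx` is dropped here and its buffer goes back to the pool
    assert_eq!(pool.available(), 1);
}
```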
```rust
use core::net::SocketAddr;

use embedded_io::{Error, ErrorType};
use embedded_io_async::{Read, Write};
use embedded_nal_async::TcpConnect;
use heapless::pool::boxed::{Box, BoxPool};

use crate::{CryptoProvider, TlsCipherSuite, TlsConfig, TlsConnection, TlsContext, TlsError};

/// A connector that wraps another connector `C` and establishes a TLS session on top of
/// each connection that `C` returns.
pub struct TlsConnector<
    'a,
    const RX: usize,
    const TX: usize,
    C,
    CipherSuite,
    MakeProvider,
    Provider: CryptoProvider,
    TlsBuffers,
> {
    /// The TLS configuration used when establishing a TLS connection
    config: TlsConfig<'a>,
    /// The connector used to open the network connection on which the TLS handshake runs
    connector: C,
    /// A constructor for a type that implements [`CryptoProvider`], called once per
    /// connection to establish the TLS session
    make_crypto_provider: MakeProvider,
    _buffers: core::marker::PhantomData<TlsBuffers>,
    _cipher_suite: core::marker::PhantomData<CipherSuite>,
    _provider: core::marker::PhantomData<Provider>,
}

impl<
        'a,
        const RX: usize,
        const TX: usize,
        C,
        CipherSuite,
        MakeProvider,
        Provider: CryptoProvider,
        TlsBuffers,
    > TlsConnector<'a, RX, TX, C, CipherSuite, MakeProvider, Provider, TlsBuffers>
{
    pub fn new(config: TlsConfig<'a>, connector: C, make_crypto_provider: MakeProvider) -> Self {
        TlsConnector {
            config,
            connector,
            make_crypto_provider,
            _buffers: core::marker::PhantomData,
            _cipher_suite: core::marker::PhantomData,
            _provider: core::marker::PhantomData,
        }
    }
}

impl<const RX: usize, const TX: usize, C, CipherSuite, MakeProvider, Provider, TlsBuffers>
    TcpConnect for TlsConnector<'_, RX, TX, C, CipherSuite, MakeProvider, Provider, TlsBuffers>
where
    C: TcpConnect,
    CipherSuite: TlsCipherSuite + 'static,
    MakeProvider: Fn() -> Provider,
    Provider: CryptoProvider<CipherSuite = CipherSuite>,
    TlsBuffers: BoxPool<Data = ([u8; RX], [u8; TX])>,
{
    type Error = TlsError;
    type Connection<'a>
        = PooledTlsConnection<
        'a,
        RX,
        TX,
        C::Connection<'a>,
        <Provider as CryptoProvider>::CipherSuite,
        TlsBuffers,
    >
    where
        Self: 'a;

    async fn connect(&self, remote: SocketAddr) -> Result<Self::Connection<'_>, Self::Error> {
        let tcp_conn = self
            .connector
            .connect(remote)
            .await
            .map_err(|err| TlsError::Io(err.kind()))?;

        // Take an RX/TX record-buffer pair from the pool; fails if the pool is exhausted.
        let mut tls_record_buffers =
            TlsBuffers::alloc(([0u8; RX], [0u8; TX])).map_err(|_| TlsError::OutOfMemory)?;

        // `TlsConnection` needs `&mut` access to the RX/TX buffers for its whole lifetime,
        // while the `Box` that owns them is moved into the returned `PooledTlsConnection`.
        // A raw pointer detaches that borrow from the local variable.
        //
        // SAFETY: moving the `Box` handle does not move the pooled memory, nothing else can
        // reach that memory, and the connection is always dropped before the `Box`: by field
        // order in `PooledTlsConnection`, and by reverse declaration order of locals if the
        // handshake below fails (so the block goes back to the pool instead of leaking, which
        // a `Box::into_raw` that is only undone on success would do).
        let buffers: *mut ([u8; RX], [u8; TX]) = &mut *tls_record_buffers;
        let (rx, tx) = unsafe { &mut *buffers };
        let mut tls_conn = TlsConnection::new(tcp_conn, rx, tx);

        let crypto_provider = (self.make_crypto_provider)();
        let context = TlsContext::new(&self.config, crypto_provider);
        tls_conn.open(context).await?;

        Ok(PooledTlsConnection {
            conn: tls_conn,
            _tls_record_buffers: tls_record_buffers,
        })
    }
}

/// A `TlsConnection` that owns its TLS record buffers, which were taken from the
/// `TlsBuffers` pool and are returned to it when the connection is dropped.
pub struct PooledTlsConnection<
    'a,
    const RX: usize,
    const TX: usize,
    Socket: Read + Write + 'a,
    CipherSuite: TlsCipherSuite + 'static,
    TlsBuffers: BoxPool<Data = ([u8; RX], [u8; TX])>,
> {
    // Field order matters: fields are dropped in declaration order, so `conn` (which
    // borrows the buffers) is dropped before the buffers are returned to the pool.
    conn: TlsConnection<'a, Socket, CipherSuite>,
    // A long-lived heapless `Box` holding the TLS record buffers. It keeps the memory
    // behind `conn`'s `&mut [u8]` borrows alive and frees it when the connection is dropped.
    _tls_record_buffers: Box<TlsBuffers>,
}

impl<
        'a,
        const RX: usize,
        const TX: usize,
        Socket: Read + Write + 'a,
        CipherSuite: TlsCipherSuite + 'static,
        TlsBuffers: BoxPool<Data = ([u8; RX], [u8; TX])>,
    > ErrorType for PooledTlsConnection<'a, RX, TX, Socket, CipherSuite, TlsBuffers>
{
    type Error = <TlsConnection<'a, Socket, CipherSuite> as ErrorType>::Error;
}

impl<
        'a,
        const RX: usize,
        const TX: usize,
        Socket: Read + Write + 'a,
        CipherSuite: TlsCipherSuite + 'static,
        TlsBuffers: BoxPool<Data = ([u8; RX], [u8; TX])>,
    > Read for PooledTlsConnection<'a, RX, TX, Socket, CipherSuite, TlsBuffers>
{
    async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
        self.conn.read(buf).await
    }
}

impl<
        'a,
        const RX: usize,
        const TX: usize,
        Socket: Read + Write + 'a,
        CipherSuite: TlsCipherSuite + 'static,
        TlsBuffers: BoxPool<Data = ([u8; RX], [u8; TX])>,
    > Write for PooledTlsConnection<'a, RX, TX, Socket, CipherSuite, TlsBuffers>
{
    async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
        self.conn.write(buf).await
    }

    async fn flush(&mut self) -> Result<(), Self::Error> {
        self.conn.flush().await
    }
}
```

Thoughts? Ping @mdelete: this is what I mentioned in the Matrix room.