Skip to content

Feature/Refactor: TLS Connector #179

@AlisCode

Description

@AlisCode

This relates to the reqwless issue reqwless#127, specifically the alternative approach proposed there.

I suggest adding a TLS connector to embedded-tls: a type that implements TcpConnect (and possibly other traits in the future).
It acts as a TCP connector, but returns a TLS connection (a wrapped TlsConnection) instead of a plain TCP connection.

The goal is to be able to use this inside, e.g., a reqwless::HttpClient, so that the TlsConfig is decoupled from reqwless itself.
This would also serve to address reqwless#123, as it should not really be a reqwless concern.

I've taken a shot at implementing this, for demonstration purposes. I'm looking for feedback on whether this is a good idea or not.

It's in rough shape, and the ergonomics may not be great. There are two reasons. First, a TLS connector has to provide the TLS record buffers needed to create each TlsConnection. Second, it must use interior mutability (here, a heapless BoxPool), because TcpConnect::connect only receives &self while a TlsConnection needs mutable access to its TLS record buffers.

use core::net::SocketAddr;

use embedded_io::{Error, ErrorType};
use embedded_io_async::{Read, Write};
use embedded_nal_async::TcpConnect;
use heapless::pool::boxed::{Box, BoxPool};

use crate::{CryptoProvider, TlsCipherSuite, TlsConfig, TlsConnection, TlsContext, TlsError};

/// A connector that wraps an implementation of a connector C, and establishes a TLS connection
/// using the provided connection.
pub struct TlsConnector<
    'a,
    const RX: usize,
    const TX: usize,
    C,
    CipherSuite,
    MakeProvider,
    Provider: CryptoProvider,
    TlsBuffers,
> {
    /// The underlying TLS Configuration to use when establishing a TLS connection
    config: TlsConfig<'a>,
    /// The connector to use to establish a network connection on which we will perform a TLS
    /// handshake
    connector: C,
    /// A constructor for a type that implements [`CryptoProvider`], which will be used to establish the TLS connection
    make_crypto_provider: MakeProvider,
    _buffers: core::marker::PhantomData<TlsBuffers>,
    _cipher_suite: core::marker::PhantomData<CipherSuite>,
    _provider: core::marker::PhantomData<Provider>,
}

impl<'a, const RX: usize, const TX: usize, C, CipherSuite, MakeProvider, Provider, TlsBuffers>
    TlsConnector<'a, RX, TX, C, CipherSuite, MakeProvider, Provider, TlsBuffers>
where
    Provider: CryptoProvider,
{
    /// Creates a connector from a TLS configuration, an inner connector and a factory for
    /// crypto providers.
    ///
    /// This only stores its arguments. No connection is opened and no buffers are allocated
    /// until a connection is requested.
    pub fn new(config: TlsConfig<'a>, connector: C, make_crypto_provider: MakeProvider) -> Self {
        Self {
            connector,
            make_crypto_provider,
            config,
            _cipher_suite: core::marker::PhantomData,
            _provider: core::marker::PhantomData,
            _buffers: core::marker::PhantomData,
        }
    }
}

impl<const RX: usize, const TX: usize, C, CipherSuite, MakeProvider, Provider, TlsBuffers>
    TcpConnect for TlsConnector<'_, RX, TX, C, CipherSuite, MakeProvider, Provider, TlsBuffers>
where
    C: TcpConnect,
    CipherSuite: TlsCipherSuite + 'static,
    MakeProvider: Fn() -> Provider,
    Provider: CryptoProvider<CipherSuite = CipherSuite>,
    TlsBuffers: BoxPool<Data = ([u8; RX], [u8; TX])>,
{
    type Error = TlsError;

    /// An established TLS session over `C`'s connection, owning its pooled record buffers.
    type Connection<'a>
        = PooledTlsConnection<
        'a,
        RX,
        TX,
        C::Connection<'a>,
        <Provider as CryptoProvider>::CipherSuite,
        TlsBuffers,
    >
    where
        Self: 'a;

    /// Connects to `remote` with the inner connector and performs a TLS handshake over it.
    ///
    /// The RX/TX record buffers are taken from the `TlsBuffers` pool *before* any network
    /// activity, so an exhausted pool does not cost a TCP connection. The pool box stays
    /// owned by this function until it is handed to the returned connection. It is therefore
    /// dropped on every failure path, including when this future is dropped mid-handshake,
    /// and otherwise when the returned connection is dropped.
    ///
    /// # Errors
    ///
    /// - [`TlsError::OutOfMemory`] if the pool has no free slot.
    /// - [`TlsError::Io`] with the error kind if the inner connector fails.
    /// - Any error returned by the TLS handshake.
    async fn connect(&self, remote: SocketAddr) -> Result<Self::Connection<'_>, Self::Error> {
        // Declared first so that, as a local, it is dropped *after* `tls_conn` (locals drop in
        // reverse declaration order) on every early-return / cancellation path.
        // NOTE(review): the zeroed tuple is built on the stack before being moved into the
        // pool; with large RX/TX this may be a significant stack cost on embedded targets.
        let mut tls_record_buffers =
            TlsBuffers::alloc(([0u8; RX], [0u8; TX])).map_err(|_| TlsError::OutOfMemory)?;

        let tcp_conn = self
            .connector
            .connect(remote)
            .await
            .map_err(|err| TlsError::Io(err.kind()))?;

        // `TlsConnection` needs `&mut` access to the RX/TX buffers, for a lifetime the borrow
        // checker cannot tie to the box, which is later moved into the returned value.
        let buffers: *mut ([u8; RX], [u8; TX]) = &mut *tls_record_buffers;
        // SAFETY: `buffers` points into a freshly allocated, exclusively owned pool box.
        // From here on the box is only moved, never dereferenced. Moving it moves the
        // handle, not the pooled data, so `rx`/`tx` are the only live references to that
        // data. The box is always dropped after `tls_conn`:
        // - on error/cancellation, by local drop order (see above);
        // - on success, because `PooledTlsConnection` declares `conn` before
        //   `_tls_record_buffers`, and struct fields drop in declaration order.
        let (rx, tx) = unsafe { &mut *buffers };
        let mut tls_conn = TlsConnection::new(tcp_conn, rx, tx);

        let crypto_provider = (self.make_crypto_provider)();
        let context = TlsContext::new(&self.config, crypto_provider);
        tls_conn.open(context).await?;

        Ok(PooledTlsConnection {
            conn: tls_conn,
            _tls_record_buffers: tls_record_buffers,
        })
    }
}

/// A TLS connection whose RX/TX record buffers are held in a `heapless` pool box.
///
/// This is the connection type produced by `TlsConnector`'s `TcpConnect::connect`. Reads,
/// writes and flushes are forwarded to the wrapped [`TlsConnection`]. The pooled buffers
/// are dropped together with this value.
pub struct PooledTlsConnection<
    'a,
    const RX: usize,
    const TX: usize,
    Socket: Read + Write + 'a,
    CipherSuite: TlsCipherSuite + 'static,
    TlsBuffers: BoxPool<Data = ([u8; RX], [u8; TX])>,
> {
    // The TLS session. Its RX/TX buffers are `&mut` borrows into `_tls_record_buffers`,
    // obtained through a raw pointer, so the borrow checker does not enforce that they
    // outlive the box.
    // NOTE: field order is load-bearing for soundness. Struct fields are dropped in
    // declaration order, so `conn` must stay declared before `_tls_record_buffers`.
    // Otherwise the buffers would be released while `conn` still borrows them.
    conn: TlsConnection<'a, Socket, CipherSuite>,
    // A long-lived heapless Box pointer containing the TLS Record buffers.
    // Its role is to keep the &mut [u8] alive, and drop it when the connection is dropped.
    _tls_record_buffers: Box<TlsBuffers>,
}

/// A pooled connection reports exactly the errors of the [`TlsConnection`] it wraps.
impl<'a, const RX: usize, const TX: usize, Socket, CipherSuite, TlsBuffers> ErrorType
    for PooledTlsConnection<'a, RX, TX, Socket, CipherSuite, TlsBuffers>
where
    Socket: Read + Write + 'a,
    CipherSuite: TlsCipherSuite + 'static,
    TlsBuffers: BoxPool<Data = ([u8; RX], [u8; TX])>,
{
    type Error = <TlsConnection<'a, Socket, CipherSuite> as ErrorType>::Error;
}

/// Reading is delegated unchanged to the wrapped [`TlsConnection`].
impl<'a, const RX: usize, const TX: usize, Socket, CipherSuite, TlsBuffers> Read
    for PooledTlsConnection<'a, RX, TX, Socket, CipherSuite, TlsBuffers>
where
    Socket: Read + Write + 'a,
    CipherSuite: TlsCipherSuite + 'static,
    TlsBuffers: BoxPool<Data = ([u8; RX], [u8; TX])>,
{
    async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
        let session = &mut self.conn;
        session.read(buf).await
    }
}

/// Writing and flushing are delegated unchanged to the wrapped [`TlsConnection`].
impl<'a, const RX: usize, const TX: usize, Socket, CipherSuite, TlsBuffers> Write
    for PooledTlsConnection<'a, RX, TX, Socket, CipherSuite, TlsBuffers>
where
    Socket: Read + Write + 'a,
    CipherSuite: TlsCipherSuite + 'static,
    TlsBuffers: BoxPool<Data = ([u8; RX], [u8; TX])>,
{
    async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
        let session = &mut self.conn;
        session.write(buf).await
    }

    async fn flush(&mut self) -> Result<(), Self::Error> {
        let session = &mut self.conn;
        session.flush().await
    }
}

Thoughts? Ping @mdelete: this is what I mentioned in the Matrix room.

Metadata

Assignees

No one assigned

    Labels

    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions