Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 17 additions & 4 deletions aioopenssl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,14 +110,17 @@ class STARTTLSTransport(asyncio.Transport):
certificate validators implementing e.g. DANE.

`server_hostname` must be either a :class:`str` or :data:`None`. It may be
used by certificate validators anrd must be the host name for which the
used by certificate validators and must be the host name for which the
peer must have a valid certificate (if host name based certificate
validation is performed). `server_hostname` is also passed via the TLS
Server Name Indication (SNI) extension if it is given.

If host names are to be converted to :class:`bytes` by the transport, they
are encoded using the ``utf-8` codec.

If `server_mode` is true, TLS will be negotiated as a server. Defaults to
false (client mode).

If `waiter` is not :data:`None`, it must be a
:class:`asyncio.Future`. After the stream has been established, the futures
result is set to a value of :data:`None`. If any errors occur, the
Expand Down Expand Up @@ -145,7 +148,8 @@ def __init__(self, loop, rawsock, protocol, ssl_context_factory,
use_starttls=False,
post_handshake_callback=None,
peer_hostname=None,
server_hostname=None):
server_hostname=None,
server_mode=False):
if not use_starttls and not ssl_context_factory:
raise ValueError("Cannot have STARTTLS disabled (i.e. immediate "
"TLS connection) and without SSL context.")
Expand Down Expand Up @@ -173,7 +177,8 @@ def __init__(self, loop, rawsock, protocol, ssl_context_factory,
ssl_object=None,
peername=self._rawsock.getpeername(),
peer_hostname=peer_hostname,
server_hostname=server_hostname
server_hostname=server_hostname,
server_mode=server_mode
)

# this is a list set of tasks which will also be cancelled if the
Expand Down Expand Up @@ -298,7 +303,11 @@ def _initiate_tls(self):
self._tls_conn = OpenSSL.SSL.Connection(
self._ssl_context,
self._sock)
self._tls_conn.set_connect_state()
# Specify whether this is client or server
if self._extra["server_mode"]:
self._tls_conn.set_accept_state()
else:
self._tls_conn.set_connect_state()
self._tls_conn.set_app_data(self)
try:
self._tls_conn.set_tlsext_host_name(
Expand Down Expand Up @@ -632,6 +641,10 @@ def starttls(self, ssl_context=None,
if post_handshake_callback is not None:
self._tls_post_handshake_callback = post_handshake_callback

# Drain before initializing TLS
while self._buffer:
yield from asyncio.sleep(0)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I haven’t looked at this into detail yet, so this may be stupid, but yield from self.drain() is no option?

Also, this makes concurrent writes undefined, which should at least be documented.

In addition, a very simple server-mode test (possibly without starttls) should be added. I’m trying to get some unit-test coverage into aioopenssl.


self._waiter = asyncio.Future()
self._waiter.add_done_callback(self._waiter_done)
self._initiate_tls()
Expand Down