Skip to content
31 changes: 27 additions & 4 deletions application/features/Audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@
from .Connection import Connection
from .. import app
from ..utils import find_free_port, get_headers_dict_from_str, local_auth
import logging

logger = logging.getLogger(__name__)

AUDIO_CONNECTIONS = {}

Expand All @@ -57,13 +60,16 @@ def __del__(self):
super().__del__()

def connect(self, *args, **kwargs):
logger.debug("Audio: Establishing Audio connection")
return super().connect(*args, **kwargs)

def launch_audio(self):
try:
logger.debug("Audio: Launching Audio connection.")
self.transport = self.client.get_transport()
self.remote_port = self.transport.request_port_forward('127.0.0.1', 0)
except Exception as e:
logger.exception("Audio: exception raised during launch audio")
return False, str(e)

self.id = uuid.uuid4().hex
Expand All @@ -83,11 +89,12 @@ def handleConnected(self):
headers = get_headers_dict_from_str(headers)
if not local_auth(headers=headers, abort_func=self.close):
# local auth failure
logger.warning("AudioWebSocket: Local Authentication Failure")
return

audio_id = self.request.path[1:]
if audio_id not in AUDIO_CONNECTIONS:
print(f'AudioWebSocket: Requested audio_id={audio_id} does not exist.')
logger.warning("AudioWebSocket: Requested audio_id=%s does not exist", audio_id)
self.close()
return

Expand All @@ -103,26 +110,35 @@ def handleConnected(self):
f'module-null-sink sink_name={sink_name} '
exit_status, _, stdout, _ = self.audio.exec_command_blocking(load_module_command)
if exit_status != 0:
print(f'AudioWebSocket: audio_id={audio_id}: unable to load pactl module-null-sink sink_name={sink_name}')
logger.warning(
"AudioWebSocket: audio_id=%s: unable to load pactl module-null-sink sink_name=%s",
audio_id,
sink_name
)
return
load_module_stdout_lines = stdout.readlines()
logger.debug("AudioWebSocket: Load Module: %s", load_module_stdout_lines)
self.module_id = int(load_module_stdout_lines[0])

keep_launching_ffmpeg = True

def ffmpeg_launcher():
logger.debug("AudioWebSocket: ffmpeg_launcher thread started")
# TODO: support requesting audio format from the client
launch_ffmpeg_command = f'killall ffmpeg; ffmpeg -f pulse -i "{sink_name}.monitor" ' \
f'-ac 2 -acodec pcm_s16le -ar 44100 -f s16le "tcp://127.0.0.1:{self.audio.remote_port}"'
# keep launching if the connection is not accepted in the writer() below
while keep_launching_ffmpeg:
logger.debug("AudioWebSocket: Launch ffmpeg: %s", launch_ffmpeg_command)
_, ffmpeg_stdout, _ = self.audio.client.exec_command(launch_ffmpeg_command)
ffmpeg_stdout.channel.recv_exit_status()
# if `ffmpeg` launches successfully, `ffmpeg_stdout.channel.recv_exit_status` should not return
logger.debug("AudioWebSocket: ffmpeg_launcher thread ended")

ffmpeg_launcher_thread = threading.Thread(target=ffmpeg_launcher)

def writer():
logger.debug("AudioWebSocket: writer thread started")
channel = self.audio.transport.accept(FFMPEG_LOAD_TIME * TRY_FFMPEG_MAX_COUNT)

nonlocal keep_launching_ffmpeg
Expand All @@ -138,14 +154,17 @@ def writer():
while True:
data = channel.recv(AUDIO_BUFFER_SIZE)
if not data:
logger.debug("AudioWebSocket: Close audio socket connection")
self.close()
break
buffer += data
if len(buffer) >= AUDIO_BUFFER_SIZE:
compressed = zlib.compress(buffer, level=4)
logger.debug("AudioWebSocket: Send compressed message of size %d", compressed)
self.sendMessage(compressed)
# print(len(compressed) / len(buffer) * 100)
logger.debug("Audio: Audio port %s", AUDIO_PORT)
buffer = b''
logger.debug("AudioWebSocket: write thread ended")

writer_thread = threading.Thread(target=writer)

Expand All @@ -155,8 +174,10 @@ def writer():
def handleClose(self):
if self.module_id is not None:
# unload the module before leaving
logger.debug("AudioWebSocket: Unload module %d", self.module_id)
self.audio.client.exec_command(f'pactl unload-module {self.module_id}')

logger.debug("AudioWebSocket: End audio socket %s connection", self.audio.id)
del AUDIO_CONNECTIONS[self.audio.id]
del self.audio

Expand All @@ -166,13 +187,15 @@ def handleClose(self):
# if we are in debug mode, run the server in the second round
if not app.debug or os.environ.get("WERKZEUG_RUN_MAIN") == "true":
AUDIO_PORT = find_free_port()
print("AUDIO_PORT =", AUDIO_PORT)
logger.debug("Audio: Audio port %s", AUDIO_PORT)

if os.environ.get('SSL_CERT_PATH') is None:
logger.debug("Audio: SSL Certification Path not set. Generating self-signing certificate")
# no certificate provided, generate self-signing certificate
audio_server = SimpleSSLWebSocketServer('127.0.0.1', AUDIO_PORT, AudioWebSocket,
ssl_context=generate_adhoc_ssl_context())
else:
logger.debug("Audio: SSL Certification Path exists")
import ssl

audio_server = SimpleSSLWebSocketServer('0.0.0.0', AUDIO_PORT, AudioWebSocket,
Expand Down
47 changes: 36 additions & 11 deletions application/features/Connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,27 +27,30 @@
import paramiko
import select

import logging

logger = logging.getLogger(__name__)

class ForwardServerHandler(socketserver.BaseRequestHandler):
def handle(self):
self.server: ForwardServer
try:
logger.debug("Connection: Open forward server channel")
chan = self.server.ssh_transport.open_channel(
"direct-tcpip",
("127.0.0.1", self.server.chain_port),
self.request.getpeername(),
)
except Exception as e:
logger.exception("Connection: Incoming request to 127.0.0.1:%d failed", self.server.chain_port)
return False, "Incoming request to %s:%d failed: %s" % (
"127.0.0.1", self.server.chain_port, repr(e))

print(
"Connected! Tunnel open %r -> %r -> %r"
% (
self.request.getpeername(),
chan.getpeername(),
("127.0.0.1", self.server.chain_port),
)
logger.info(
"Connected! Tunnel open %r -> %r -> %r",
self.request.getpeername(),
chan.getpeername(),
("127.0.0.1", self.server.chain_port),
)

try:
Expand All @@ -64,13 +67,15 @@ def handle(self):
break
self.request.send(data)
except Exception as e:
print(e)
logger.exception("Connection: Error occurred during data transfer")

try:
logger.debug("Connection: Close forward server channel")
chan.close()
self.server.shutdown()
except Exception as e:
print(e)
logger.exception("Connection: Close forward server channel failed")


class ForwardServer(socketserver.ThreadingTCPServer):
Expand Down Expand Up @@ -102,6 +107,9 @@ def __del__(self):
def _client_connect(self, client: paramiko.SSHClient,
host, username,
password=None, key_filename=None, private_key_str=None):
if self._jump_channel is not None:
logger.debug("Connection: Connection initialized through Jump Channel")
logger.debug("Connection: Connecting to %s@%s", username, host)
if password is not None:
client.connect(host, username=username, password=password, timeout=15, sock=self._jump_channel)
elif key_filename is not None:
Expand All @@ -128,23 +136,26 @@ def _init_jump_channel(self, host, username, **auth_methods):

self._jump_client = paramiko.SSHClient()
self._jump_client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
logger.debug("Connection: Initialize Jump Client for connection to %s@remote.ecf.utoronto.ca", username)
self._client_connect(self._jump_client, 'remote.ecf.utoronto.ca', username, **auth_methods)
logger.debug("Connection: Open Jump channel connection to %s at port 22", host)
self._jump_channel = self._jump_client.get_transport().open_channel('direct-tcpip',
(host, 22),
('127.0.0.1', 22))

def connect(self, host: str, username: str, **auth_methods):
try:
logger.debug("Connection: Connection attempt to %s@%s", username, host)
self._init_jump_channel(host, username, **auth_methods)
self._client_connect(self.client, host, username, **auth_methods)
except Exception as e:
# raise e
# print('Connection::connect() exception:')
logger.exception("Connection: Connection attempt to %s@%s failed", username, host)
return False, str(e)

self.host = host
self.username = username

logger.debug("Connection: Successfully connected to %s@%s", username, host)
return True, ''

@staticmethod
Expand All @@ -160,9 +171,11 @@ def ssh_keygen(key_filename=None, key_file_obj=None, public_key_comment=''):

# save the private key
if key_filename is not None:
logger.debug("Connection: RSA SSH private key written to %s", key_filename)
rsa_key.write_private_key_file(key_filename)
elif key_file_obj is not None:
rsa_key.write_private_key(key_file_obj)
logger.debug("Connection: RSA SSH private key written to %s", key_file_obj)
else:
raise ValueError('Neither key_filename nor key_file_obj is provided.')

Expand Down Expand Up @@ -192,6 +205,7 @@ def save_keys(self, key_filename=None, key_file_obj=None, public_key_comment='')
"mkdir -p ~/.ssh && chmod 700 ~/.ssh && echo '%s' >> ~/.ssh/authorized_keys" % pub_key)
if exit_status != 0:
return False, "Connection::save_keys: unable to save public key; Check for disk quota and permissions with any conventional SSH clients. "
logger.debug("Connection: Public ssh key saved to remove server ~/.ssh/authorized_keys")

return True, ""

Expand All @@ -217,22 +231,30 @@ def exec_command_blocking_large(self, command):
return '\n'.join(stdout) + '\n' + '\n'.join(stderr)

def _port_forward_thread(self, local_port, remote_port):
logger.debug("Connection: Port forward thread started")
forward_server = ForwardServer(("", local_port), ForwardServerHandler)

forward_server.ssh_transport = self.client.get_transport()
forward_server.chain_port = remote_port

forward_server.serve_forever()
forward_server.server_close()
logger.debug("Connection: Port forward thread ended")

def port_forward(self, *args):
forwarding_thread = threading.Thread(target=self._port_forward_thread, args=args)
forwarding_thread.start()

def is_eecg(self):
return 'eecg' in self.host
if 'eecg' in self.host:
logger.debug("Connection: Target host is eecg")
return True

return False

def is_ecf(self):
if 'ecf' in self.host:
logger.debug("Connection: Target host is ecf")
return 'ecf' in self.host

def is_uoft(self):
Expand All @@ -256,6 +278,9 @@ def is_load_high(self):

my_pts_count = len(output) - 1 # -1: excluding the `uptime` output

logger.debug("Connection: pts count: %d; my pts count: %d", pts_count, my_pts_count)
logger.debug("Connection: load sum: %d", load_sum)

if pts_count > my_pts_count: # there are more terminals than mine
return True
elif load_sum > 1.0:
Expand Down
Loading
Loading