diff --git a/README.md b/README.md index 1393880..3a80607 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,62 @@ # tnet -A file transfer (and messages) tool for local network \ No newline at end of file +Purpose +======= +A minimalist and hackable file transfer (and messages) tool for the local network + +Description +=========== ++ Built with python's standard library with no 3rd-party dependencies. + ++ Utilizes the socket and selectors modules for networking. + ++ Due to the use of the selectors module the app is async and single-threaded. + ++ Requires 2 ports to be opened (which ones is configurable). + ++ Does not preserve file attributes when transferring a file. + ++ Designed for the local network only, thus no encryption is implemented. + ++ Tested with python 3.8, 3.11 and 3.12 on Ubuntu 20.04, Debian 12 and Alpine 3.19 + ++ Licensed under the **GPLv3** license + +The architecture looks as follows: +there are 4 sockets - one listening for incoming TCP connections, one listening for incoming UDP datagrams (discovery mechanism), +the other 2 are UNIX sockets - one for accepting the local commands (send file, send notification, discover who is online, etc.) and the second one is +for distributing events like incoming notifications, file transfer progress (both incoming and outgoing), errors, etc. +The first 2 sockets are mandatory and are used by the daemon for communication over the local network, the other 2 are optional and are used for communication +between cli and daemon on the same machine. + +There are only 4 main source files: ``protocol.py``, ``daemon.py``, ``cli.py`` and ``util.py``, their descriptions are in the corresponding files. 
+ +Typical workflow: +1) the daemon is started and listens on the configured TCP and UDP ports, and on one local UNIX socket for incoming commands +2) a separate cli instance sends the 'discover' command to the UNIX socket to get a list of available machines in the local network +3) the daemon in turn gets this command and broadcasts the 'discover' request via UDP +4) a daemon on another machine (if any) gets this request on the UDP socket it listens to and replies directly to the requester via TCP with its IPv4 address and host name +5) the daemon on the first machine gets it and replies to the cli via the same UNIX socket +6) then this info can be used via cli to send a file or notification, again the command is sent to the daemon via the UNIX socket +7) the daemon sends the file or notification via TCP, and optionally can report the progress (or errors) back to ANOTHER UNIX socket +8) the counterparty daemon processes the incoming data over TCP and also can optionally report the progress on a UNIX socket, called the events socket +9) once the file is received it is moved from ``/tmp`` to the configured downloads directory, if the transfer fails - the file gets removed from ``/tmp`` + +Configuration +============= +The following env variables can be set for **daemon**: ++ ``EVENT_SOCKET``, a path to the UNIX socket the events should be reported to, defaults to ``'./run/tnet/events'`` ++ ``COMMAND_SOKET``, a path to the UNIX socket the commands should be read from, defaults to ``'./run/tnet/commands'`` ++ ``BROADCAST_PORT``, 'discovery' UDP port each daemon listens to, defaults to ``0xC0DE`` ++ ``RECEIVE_PORT``, main data transfer TCP port, defaults to ``0xC00D`` ++ ``FILE_SAVE_DIR``, a directory where the downloaded files will be stored, defaults to ``'Downloads'`` ++ ``REPORT_PROGRESS_IN``, 0 or 1, if a CLI or other app needs to track progress via EVENT_SOCKET of **incoming** file transfer, defaults to 0 ++ ``REPORT_PROGRESS_OUT``, 0 or 1, if a CLI or other app needs to track progress via EVENT_SOCKET of 
**outgoing** file transfer, defaults to 0
+
+The following env variables can be set for **cli**:
++ ``EVENT_SOCKET``, a path to the UNIX socket the events can be read from, defaults to ``'./run/tnet/events'``
++ ``COMMAND_SOKET``, a path to the UNIX socket the commands should be sent to, defaults to ``'./run/tnet/commands'``
+
+
+
+
diff --git a/src/cli.py b/src/cli.py
new file mode 100755
index 0000000..1c961af
--- /dev/null
+++ b/src/cli.py
@@ -0,0 +1,234 @@
+#!/usr/bin/env python3
+
+# Can be run in three modes: first mode sends commands to the daemon, like 'file transfer', 'notification', etc.
+# Obviously a target host should be specified, where the second 'discovery' mode comes in handy
+# The third mode is a 'listening' mode where it listens to incoming notifications, progress info, errors, etc.
+# Required 2 UNIX sockets to be configured - one for the commands, second for listening to incoming messages
+# A custom implementation of Pattern can be used to add some more sophisticated logic for incoming data
+
+import os
+import sys
+import argparse
+import selectors
+import traceback
+
+from pathlib import Path
+
+from protocol import Pattern
+from protocol import Missive
+from protocol import NoReply
+from daemon import Env
+from daemon import Defaults
+from util import accept_connection
+from util import send_via_unix
+from util import register_unix_handler
+
+_COMMAND_SOKET = None
+_EVENT_SOCKET = None
+
+# ANSI escape sequences used to redraw the progress line in place
+LINE_UP = '\033[1A'
+LINE_CLEAR = '\x1b[2K'
+
+
+class PrintOutDiscovery(Pattern):
+    """Client pattern for printing discovery info."""
+
+    def onprogress(self, handler, missive):
+        pass
+
+    def onsent(self, handler, missive):
+        pass
+
+    def onreceived(self, handler, missive):
+        """Print one ``key<TAB>value`` line per entry of the reply, then close."""
+        if missive.error:
+            # a failed receive carries no usable payload
+            print(f'ERROR {missive.error}', file=sys.stderr)
+        elif isinstance(missive.data, dict):
+            for k, v in missive.data.items():
+                print(f'{k}\t{v}')
+        if handler.canclose:
+            handler.close()
+
+
+class PrintOutEvent(Pattern):
+    """Client pattern for printing events in console."""
+
+    def onprogress(self, handler, missive):
+        pass
+
+    def onsent(self, handler, missive):
+        pass
+
+    def onreceived(self, handler, missive):
+        """Print a received event according to its ``action`` field.
+
+        Recognised actions are ``onprogress``, ``onnotify`` and ``onerror``;
+        anything else is ignored.
+        """
+        if handler.canclose:
+            handler.close()
+        if missive.error:
+            print(f'ERROR {missive.error}')
+            return
+        if not isinstance(missive.data, dict):
+            # nothing to look an action up in
+            return
+        action = missive.data.get('action', None)
+        if not action:
+            return
+        elif action == 'onprogress':
+            # "or 1" avoids a ZeroDivisionError for a zero/None total
+            total = missive.data.get('total') or 1
+            now = missive.data.get('processed', 0)
+            fname = missive.data.get('fname', 'unknown')
+            way = missive.data.get('way', '').upper()
+            print(f'{way} {fname} {now * 100 / total:.2f}%')
+            # move the cursor back up so the next report overwrites this one
+            print(LINE_UP, end=LINE_CLEAR)
+        elif action == 'onnotify':
+            notification = missive.data.get('txt')
+            print(notification)
+        elif action == 'onerror':
+            notification = missive.data.get('txt')
+            print(f'ERROR {notification}')
+        else:
+            pass
+
+
+def get_run_options() -> dict:
+    """Parse the command line.
+
+    Returns:
+        dict with the keys ``file``, ``message``, ``target``, ``discover``
+        and ``listen``. ``file`` and ``message`` are empty strings when the
+        option was not given; ``target`` is None when it was not given.
+    """
+    parser = argparse.ArgumentParser(
+        prog='tnet',
+        description='Sends files or messages to another machine')
+    parser.add_argument('--file',
+                        '-f',
+                        type=str,
+                        help='Full(!) path to a file to be sent',
+                        required=False)
+    parser.add_argument('--message',
+                        '-m',
+                        type=str,
+                        help='A message to be sent',
+                        required=False)
+    parser.add_argument('--target',
+                        '-t',
+                        type=str,
+                        help='Local IP address of a target machine (IPv4)',
+                        required=False)
+    parser.add_argument('--discover',
+                        '-d',
+                        action='store_true',
+                        help='Checks local network for listening ' +
+                        'daemons on other machines',
+                        required=False)
+    parser.add_argument('--listen',
+                        '-l',
+                        action='store_true',
+                        help='Start listening a UNIX socket for ' +
+                        'events from the local daemon',
+                        required=False)
+    args = parser.parse_args()
+    return dict(listen=args.listen,
+                discover=args.discover,
+                target=args.target,
+                file=args.file or '',
+                message=args.message or '')
+
+
+def loop(sel) -> int:
+    """Dispatch selector events until no handler is registered any more.
+
+    Returns:
+        0 on success, 1 if a handler raised or the loop was interrupted.
+    """
+    exit_code = 0
+    try:
+        while True:
+            events = sel.select(timeout=None)
+            for key, mask in events:
+                handler = key.data
+                try:
+                    handler.handle_events(mask)
+                except Exception as ex:
+                    exit_code = 1
+                    print(
+                        f'Exception for {handler.addr}, {ex}:\n'
+                        f'{traceback.format_exc()}',
+                        file=sys.stderr)
+                    handler.close()
+            # every handler unregisters itself on close(); an empty map
+            # means all the work is done
+            if not sel.get_map():
+                break
+    except KeyboardInterrupt:
+        exit_code = 1
+        print('Caught keyboard interrupt, exiting', file=sys.stderr)
+    finally:
+        sel.close()
+    return exit_code
+
+
+def recycle_old_sockets():
+    """Sockets will be created on bind(), however the directory tree must exist.
+
+    Old files have to be removed to prevent data corruption.
+    """
+    global _EVENT_SOCKET
+    evt_file = Path(_EVENT_SOCKET)
+    evt_file.parent.mkdir(parents=True, exist_ok=True)
+    evt_file.unlink(missing_ok=True)
+
+
+def as_events_listener(sel):
+    """Listen on the events UNIX socket and print every event until interrupted."""
+    global _EVENT_SOCKET
+    _EVENT_SOCKET = os.environ.get(Env.EVENT_SOCKET, Defaults.EVENT_SOCKET)
+    recycle_old_sockets()
+    register_unix_handler(sel, _EVENT_SOCKET)
+    try:
+        while True:
+            events = sel.select(timeout=None)
+            for key, mask in events:
+                if key.data is None:
+                    # no handler attached yet: this is the listening socket
+                    accept_connection(sel, key.fileobj, PrintOutEvent())
+                else:
+                    try:
+                        key.data.handle_events(mask)
+                    except Exception:
+                        print(
+                            f'Handling failed, connection will be closed: \n'
+                            f'{traceback.format_exc()}',
+                            file=sys.stderr)
+                        key.data.close()
+    except KeyboardInterrupt:
+        print('Caught interrupt, exiting', file=sys.stderr)
+    finally:
+        print('Closing selector', file=sys.stderr)
+        sel.close()
+
+
+if __name__ == '__main__':
+    # Parse the arguments first so that "--help" and "--listen" keep working
+    # when the daemon (and thus its command socket) is not there.
+    params = get_run_options()
+    payloads = []
+    listen_mode = params['listen']
+    discovery_mode = params['discover']
+    filesend_mode = params['file'] and params['target'] and Path(params['file']).exists()
+    notification_mode = params['message'] and params['target']
+
+    sel = selectors.DefaultSelector()
+    if listen_mode:
+        sys.exit(as_events_listener(sel))
+
+    # every remaining mode talks to the daemon via the command socket
+    _COMMAND_SOKET = os.environ.get(Env.COMMAND_SOKET, Defaults.COMMAND_SOKET)
+    if not Path(_COMMAND_SOKET).exists():
+        sys.exit(f'socket file {_COMMAND_SOKET} not found')
+
+    if discovery_mode:
+        print('appending a discover command')
+        data = dict(action='discover')
+        payloads.append(Missive(data))
+    elif filesend_mode:
+        print(f'appending a file {params["file"]}')
+        # the daemon opens the file itself, relative to ITS working
+        # directory, so always hand over an absolute path
+        data = dict(action='filetransfer',
+                    filepath=str(Path(params['file']).resolve()),
+                    target=params['target'])
+        payloads.append(Missive(data))
+    elif notification_mode:
+        print(f'appending a msg {params["message"]}')
+        data = dict(action='notify',
+                    value=params['message'],
+                    target=params['target'])
+        payloads.append(Missive(data))
+    else:
+        if params['file'] and not Path(params['file']).exists():
+            sys.exit(f'file {params["file"]} not found')
+        sys.exit('try to run it with the "--help" option first')
+
+    try:
+        if discovery_mode:
+            send_via_unix(sel, _COMMAND_SOKET, PrintOutDiscovery(), *payloads)
+        else:
+            send_via_unix(sel, _COMMAND_SOKET, NoReply(), *payloads)
+    except ConnectionRefusedError:
+        sys.exit('Connection refused via UNIX socket, check if local daemon is running')
+    except Exception as e:
+        sys.exit(f'Connection failed via UNIX socket: {e}')
+    else:
+        sys.exit(loop(sel))
diff --git a/src/daemon.py b/src/daemon.py
new file mode 100755
index 0000000..faefc1c
--- /dev/null
+++ b/src/daemon.py
@@ -0,0 +1,410 @@
+#!/usr/bin/env python3
+
+# Listens for incoming connections using the selectors module, once a connection is established,
+# executes some action like accepting a file transfer, relaying a notification to the dedicated UNIX socket,
+# replying with the host name and its IPv4 address, etc.
+# While sending a file, the progress info can be optionally posted to the dedicated UNIX socket (called events socket)
+# for further processing by an external app
+# Class called LordCommander plays a role of an orchestrator and contains all the main logic, and also is a Pattern itself
+# Due to the use of the selectors module, the daemon is asynchronous and single-threaded
+
+import os
+import sys
+import socket
+import selectors
+import traceback
+import signal
+import shutil
+
+from pathlib import Path
+
+from protocol import Pattern
+from protocol import Missive
+from protocol import Parcel
+from protocol import NoReply
+
+from util import register_tcp_handler
+from util import register_unix_handler
+from util import register_udp_handler
+from util import send_via_tcp
+from util import broadcast_via_udp
+from util import send_via_unix
+from util import accept_connection
+
+_COMMAND_SOCKET = None
+_EVENT_SOCKET = None
+_BROACAST_PORT = None
+_RECEIVE_PORT = None
+_FILE_SAVE_DIR = None
+
+
+class SupportedActions:
+    """Values of the ``action`` field the daemon dispatches on."""
+    DISCOVER = 'discover'
+    BROADCAST = 'broadcast'
+    NOTIFY = 'notify'
+    QUERY = 'query'
+    FILETRANSFER = 'filetransfer'
+    EXPOSE = 'expose'
+    EXPOSED = 'exposed'
+
+
+class Env:
+    """Names of the environment variables read by the daemon and the cli."""
+    EVENT_SOCKET = 'EVENT_SOCKET'
+    # was 'EVENT_SOCKET' (copy-paste slip), which made the command socket
+    # follow the events socket setting; the README documents COMMAND_SOKET
+    COMMAND_SOKET = 'COMMAND_SOKET'
+    BROADCAST_PORT = 'BROADCAST_PORT'
+    RECEIVE_PORT = 'RECEIVE_PORT'
+    FILE_SAVE_DIR = 'FILE_SAVE_DIR'
+    REPORT_PROGRESS_IN = 'REPORT_PROGRESS_IN'
+    REPORT_PROGRESS_OUT = 'REPORT_PROGRESS_OUT'
+
+
+class Defaults:
+    """Fallback values used when the matching Env variable is not set."""
+    EVENT_SOCKET = './run/tnet/events'
+    COMMAND_SOKET = './run/tnet/commands'
+    BROADCAST_PORT = 0xC0DE
+    RECEIVE_PORT = 0xC00D
+    FILE_SAVE_DIR = 'Downloads'
+    REPORT_PROGRESS_IN = 0
+    REPORT_PROGRESS_OUT = 0
+
+
+def _env_int(name, default):
+    """Return env variable ``name`` as an int (decimal or 0x-prefixed hex)."""
+    raw = os.environ.get(name)
+    if raw is None or not raw.strip():
+        return default
+    try:
+        # base 0 accepts both "49374" and "0xC0DE"
+        return int(raw.strip(), 0)
+    except ValueError:
+        sys.exit(f'{name} must be an integer, got {raw!r}')
+
+
+def _env_flag(name, default):
+    """Return env variable ``name`` as a bool; '', '0', 'false', 'no', 'off' are False."""
+    raw = os.environ.get(name)
+    if raw is None:
+        return bool(default)
+    return raw.strip().lower() not in ('', '0', 'false', 'no', 'off')
+
+
+class LordCommander(Pattern):
+    """Server pattern.
+
+    Dispatches every received missive to the handler registered for its
+    ``action`` and relays notifications, progress and errors to the events
+    UNIX socket (when one is configured).
+    """
+
+    def __init__(self, *, lan_ip, broadcast_lan_ip, main_selector,
+                 broadcast_port, receive_port,
+                 file_save_dir, event_socket_path,
+                 report_incoming_progress, report_outcoming_progress):
+        self.lan_ip = lan_ip
+        self.broadcast_lan_ip = broadcast_lan_ip
+        self.main_selector = main_selector
+        self.broadcast_port = broadcast_port
+        self.receive_port = receive_port
+        self.file_save_dir = file_save_dir
+        self.event_socket_path = event_socket_path
+        self.report_incoming_progress = report_incoming_progress
+        self.report_outcoming_progress = report_outcoming_progress
+        # action name -> bound handler(handler, missive)
+        self._actions = {
+            SupportedActions.DISCOVER: self._discover,
+            SupportedActions.BROADCAST: self._broadcast,
+            SupportedActions.NOTIFY: self._notify,
+            SupportedActions.QUERY: self._query,
+            SupportedActions.FILETRANSFER: self._filetransfer,
+            SupportedActions.EXPOSE: self._expose,
+            SupportedActions.EXPOSED: self._exposed
+        }
+        # ip -> host name of the daemons that answered a discovery
+        self._discovered = {}
+        # missive id -> last reported percentage of an incoming transfer
+        self._progress = {}
+
+    def onprogress(self, handler, missive):
+        """Report the progress of an incoming missive to the events socket.
+
+        A report is sent only when the integer percentage has changed since
+        the previous report for the same missive.
+        """
+        if not self.event_socket_path:
+            return
+        if not self.report_incoming_progress:
+            return
+
+        # the json header may not have arrived yet
+        header = missive.jsonheader or {}
+        f_name = header.get('filename', 'unknown')
+        # "or 1" also covers a zero content-length (empty file)
+        total = header.get('content-length') or 1
+        processed = missive.payload_processed_b or 0
+        pct = int(processed * 100 / total)
+        if self._progress.get(missive.missive_id) == pct:
+            return
+        self._progress[missive.missive_id] = pct
+
+        try:
+            send_via_unix(self.main_selector,
+                          self.event_socket_path,
+                          NoReply(),
+                          Missive(dict(action='onprogress',
+                                       total=total,
+                                       processed=processed,
+                                       way='in',
+                                       fname=f_name,
+                                       mid=str(missive.missive_id))))
+        except (FileNotFoundError, ConnectionRefusedError):
+            # nobody listens on the events socket - reporting is best-effort
+            pass
+
+        if pct >= 100:
+            self._progress.pop(missive.missive_id, None)
+
+    def onsent(self, handler, missive):
+        if handler.canclose:
+            handler.close()
+        if missive.error:
+            self._report_error(handler, missive)
+
+    def onreceived(self, handler, missive):
+        """Dispatch a completely received (or failed) missive."""
+        if missive.error:
+            self._report_error(handler, missive)
+            return
+
+        content_type = (missive.jsonheader or {}).get('content-type')
+        if content_type == 'text/json':
+            # valid JSON is not necessarily an object; only dicts carry actions
+            action = (missive.data.get('action')
+                      if isinstance(missive.data, dict) else None)
+            self._actions.get(action, self._notify)(handler, missive)
+        elif content_type == 'file':
+            if handler.canclose:
+                handler.close()
+            self._progress.pop(missive.missive_id, None)
+            curr_name = missive.data
+            # The file name comes from the peer: keep the last path component
+            # only, so that e.g. '../../x' cannot escape the save directory.
+            orig_name = Path(missive.jsonheader.get('filename') or '').name
+            if orig_name in ('', '.', '..'):
+                orig_name = Path(curr_name).name
+            Path(self.file_save_dir).mkdir(parents=True, exist_ok=True)
+            shutil.move(curr_name,
+                        Path(self.file_save_dir) / orig_name)
+        else:
+            self._notify(handler, missive)
+
+    def _filetransfer(self, handler, missive):
+        """Send the file named by ``filepath`` to ``target``.
+
+        A transfer to this very machine is a plain copy into the save dir.
+        """
+        if handler.canclose:
+            handler.close()
+        target = missive.data.get('target', '127.0.0.1')
+        path = Path(missive.data.get('filepath', ''))
+        if not path.is_file():
+            print(f'File {path} not found (ensure "filepath" is provided)',
+                  file=sys.stderr)
+            return
+        if target in ['localhost', '127.0.0.1', self.lan_ip]:
+            # same destination directory as for a file received over TCP
+            save_dir = Path(self.file_save_dir)
+            save_dir.mkdir(parents=True, exist_ok=True)
+            shutil.copy2(path, save_dir / path.name)
+        else:
+            pattern = (ReportSendProgress(self.main_selector, self.event_socket_path)
+                       if self.report_outcoming_progress else NoReply())
+            send_via_tcp(self.main_selector, target, self.receive_port,
+                         pattern,
+                         Parcel(path))
+
+    def _discover(self, handler, missive):
+        """Broadcast an 'expose' request and reply with the hosts known so far."""
+        broadcast_via_udp(self.main_selector, self.broadcast_lan_ip,
+                          self.broadcast_port, Missive(dict(
+                              action=SupportedActions.EXPOSE,
+                              replyto=self.lan_ip)))
+        # write back whatever we have at the moment (can be nothing)
+        handler.enqueue(Missive(self._discovered))
+        handler.set_selector_events_mask('w')
+
+    def _expose(self, handler, missive):
+        """Answer a discovery broadcast with this host's ip and name."""
+        print(missive)
+        target = missive.data.get('replyto', None)
+        # no address to reply to, or our own broadcast coming back
+        if not target or target == self.lan_ip:
+            return
+        payload = dict(action=SupportedActions.EXPOSED,
+                       ip=self.lan_ip,
+                       host=socket.gethostname())
+        send_via_tcp(self.main_selector, target, self.receive_port,
+                     NoReply(), Missive(payload))
+
+    def _exposed(self, handler, missive):
+        """Remember a daemon that answered the discovery broadcast."""
+        if handler.canclose:
+            handler.close()
+        ip = missive.data.get('ip', None)
+        name = missive.data.get('host', None)
+        if ip:
+            self._discovered[ip] = name
+
+    def _notify(self, handler, missive):
+        """Forward a notification to ``target`` or deliver it locally."""
+        if handler.canclose:
+            handler.close()
+        if isinstance(missive.data, dict):
+            payload = missive.data.get('value', missive.data)
+            target = missive.data.get('target', None)
+            if target and payload:
+                send_via_tcp(self.main_selector, target, self.receive_port,
+                             NoReply(), Missive(payload))
+            else:
+                self._accept_notification(payload)
+        else:
+            self._accept_notification(missive.data)
+
+    def _accept_notification(self, notification):
+        """Post a notification to the events socket (best-effort)."""
+        if not self.event_socket_path:
+            return
+
+        try:
+            send_via_unix(self.main_selector,
+                          self.event_socket_path,
+                          NoReply(),
+                          Missive(dict(action='onnotify',
+                                       txt=notification)))
+        except (FileNotFoundError, ConnectionRefusedError):
+            pass
+
+    def _report_error(self, handler, missive):
+        """Clean up after a failed missive and post the error to the events socket."""
+        if handler.canclose:
+            handler.close()
+
+        self._progress.pop(missive.missive_id, None)
+
+        # an outgoing missive (and a broken incoming one) has no json header
+        header = missive.jsonheader or {}
+        if header.get('content-type') == 'file' and missive.data:
+            # remove the partially received temporary file
+            Path(missive.data).unlink(missing_ok=True)
+
+        if not self.event_socket_path:
+            return
+        try:
+            send_via_unix(self.main_selector,
+                          self.event_socket_path,
+                          NoReply(),
+                          Missive(dict(action='onerror',
+                                       txt=missive.error)))
+        except (FileNotFoundError, ConnectionRefusedError):
+            pass
+
+    def _broadcast(self, handler, missive):
+        pass
+
+    def _query(self, handler, missive):
+        payload = missive.data.get('value')
+        if not payload:
+            handler.close()
+        else:
+            handler.enqueue(Missive(f'replying to: {payload}'))
+            handler.set_selector_events_mask('w')
+
+
+class ReportSendProgress(Pattern):
+    """Pattern that writes progress status to event socket."""
+
+    def __init__(self, selector, event_socket_path):
+        self._selector = selector
+        self._socket_path = event_socket_path
+        # last percentage that has been reported
+        self._sent_pct = 0
+
+    def onprogress(self, handler, missive):
+        """Report the outgoing progress whenever the integer percentage changes."""
+        if not self._socket_path:
+            return
+        if not isinstance(missive, Parcel):
+            return
+
+        total = missive.size if missive.size else 1
+        processed = (missive.payload_processed_b
+                     if missive.payload_processed_b else 0)
+        pct = int(processed * 100 / total)
+        if pct == self._sent_pct:
+            return
+        else:
+            self._sent_pct = pct
+
+        try:
+            send_via_unix(self._selector,
+                          self._socket_path,
+                          NoReply(),
+                          Missive(dict(action='onprogress',
+                                       total=total,
+                                       processed=processed,
+                                       way='out',
+                                       fname=missive.filename,
+                                       mid=str(missive.missive_id))))
+        except (FileNotFoundError, ConnectionRefusedError):
+            pass
+
+    def onsent(self, handler, missive):
+        """Close the connection and report the outcome of the transfer."""
+        if handler.canclose:
+            handler.close()
+
+        if not self._socket_path or not isinstance(missive, Parcel):
+            return
+        notification = {}
+        if missive.error:
+            notification = dict(action='onerror',
+                                txt=f'{missive.path}: {missive.error}')
+        else:
+            notification = dict(action='onnotify',
+                                txt=f'{missive.path} has been sent')
+        try:
+            send_via_unix(self._selector,
+                          self._socket_path,
+                          NoReply(),
+                          Missive(notification))
+        except (FileNotFoundError, ConnectionRefusedError):
+            pass
+
+    def onreceived(self, handler, missive):
+        pass
+
+
+def signal_handler(sig, frame):
+    """Turn SIGINT/SIGTERM into KeyboardInterrupt so the main loop can clean up."""
+    raise KeyboardInterrupt
+
+
+def get_lan_ips():
+    """Return a ``(lan_ip, lan_broadcast_ip)`` tuple of IPv4 strings.
+
+    The broadcast address is derived by replacing the last octet with 255.
+    NOTE(review): that is only right for a /24 network - confirm this is
+    acceptable for the intended deployments.
+    """
+    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
+        sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
+        # the local address is read back from the socket after connect()
+        sock.connect(('', 65535))
+        lan_ip = sock.getsockname()[0]
+        chunks = lan_ip.split('.')
+        chunks[3] = '255'
+        lan_broadcast_ip = '.'.join(chunks)
+        return (lan_ip, lan_broadcast_ip)
+
+
+def recycle_old_sockets():
+    """Sockets will be created on bind(), however the directory tree must exist.
+
+    Old files have to be removed to prevent data corruption.
+    """
+    global _COMMAND_SOCKET
+    cmd_file = Path(_COMMAND_SOCKET)
+    cmd_file.parent.mkdir(parents=True, exist_ok=True)
+    cmd_file.unlink(missing_ok=True)
+
+
+if __name__ == '__main__':
+
+    # register handlers for INTERRUPT and TERMINATE signals
+    # SIGKILL cannot be handled
+    signal.signal(signal.SIGINT, signal_handler)
+    signal.signal(signal.SIGTERM, signal_handler)
+
+    _COMMAND_SOCKET = os.environ.get(Env.COMMAND_SOKET, Defaults.COMMAND_SOKET)
+    _EVENT_SOCKET = os.environ.get(Env.EVENT_SOCKET, Defaults.EVENT_SOCKET)
+    # os.environ only holds strings: ports must be converted to int and
+    # the 0/1 flags to bool (the string '0' would otherwise be truthy)
+    _BROACAST_PORT = _env_int(Env.BROADCAST_PORT,
+                              Defaults.BROADCAST_PORT)
+    _RECEIVE_PORT = _env_int(Env.RECEIVE_PORT, Defaults.RECEIVE_PORT)
+    _FILE_SAVE_DIR = os.environ.get(Env.FILE_SAVE_DIR, Defaults.FILE_SAVE_DIR)
+    _REPORT_PROGRESS_IN = _env_flag(Env.REPORT_PROGRESS_IN,
+                                    Defaults.REPORT_PROGRESS_IN)
+    _REPORT_PROGRESS_OUT = _env_flag(Env.REPORT_PROGRESS_OUT,
+                                     Defaults.REPORT_PROGRESS_OUT)
+
+    recycle_old_sockets()
+
+    lan_ip, broadcast_lan_ip = get_lan_ips()
+    sel = selectors.DefaultSelector()
+
+    # a bunch of listeners to be used for communication
+    udp_handler = register_udp_handler(sel, broadcast_lan_ip, _BROACAST_PORT)
+    register_unix_handler(sel, _COMMAND_SOCKET)
+    register_tcp_handler(sel, lan_ip, _RECEIVE_PORT)
+
+    orchestrator = LordCommander(
+        lan_ip=lan_ip,
+        broadcast_lan_ip=broadcast_lan_ip,
+        main_selector=sel,
+        broadcast_port=_BROACAST_PORT,
+        receive_port=_RECEIVE_PORT,
+        file_save_dir=_FILE_SAVE_DIR,
+        event_socket_path=_EVENT_SOCKET,
+        report_incoming_progress=_REPORT_PROGRESS_IN,
+        report_outcoming_progress=_REPORT_PROGRESS_OUT
+    )
+
+    udp_handler.setpattern(orchestrator)
+
+    try:
+        while True:
+            events = sel.select(timeout=None)
+            for key, mask in events:
+                if key.data is None:
+                    # no handler attached yet: this is a listening socket
+                    accept_connection(sel, key.fileobj, orchestrator)
+                else:
+                    try:
+                        key.data.handle_events(mask)
+                    except Exception:
+                        print(
+                            f'Handling failed, connection will be closed: \n'
+                            f'{traceback.format_exc()}',
+                            file=sys.stderr)
+                        key.data.close()
+    except KeyboardInterrupt:
+        print('Caught interrupt, exiting', file=sys.stderr)
+    finally:
+        sel.close()
diff --git a/src/protocol.py b/src/protocol.py
new file mode 100644
index 0000000..5cada15
--- /dev/null
+++ b/src/protocol.py
@@ -0,0 +1,450 @@
+# Missive and Parcel classes are 'data transfer' entities which control
+# encoding/decoding, length of a transferred object,
+# whereas DefaultHandler contains logic for sending/receiving them
+# using selectors - thus it controls when a socket can be closed,
+# when to switch the selector from write to read mode, etc.
+# There is also a concept of a 'Pattern', a concrete implementation of which
+# executes some logic on 3 events: onsent, onreceived, and onprogress.
+# for example NoReply pattern closes the handler in onsent() which means
+# once a message or a file was sent, no reply is expected and we are done here
+
+
+import sys
+import selectors
+import json
+import io
+import struct
+import socket
+from abc import ABC, abstractmethod
+import tempfile
+import os
+import uuid
+
+
+class Missive:
+    """A single message: a str or dict to send, or a message being received.
+
+    Wire format: a 2-byte big-endian length of the json header, the json
+    header itself (utf-8) and then ``content-length`` bytes of payload.
+    """
+
+    def __init__(self, data=None):
+        self.data = data
+        self.jsonheader = None
+        self.payload_processed_b = 0
+        self._complete = False
+        self._recv_buffer = b''
+        self._send_buffer = self._encode()
+        self._jsonheader_len = None
+        self.missive_id = uuid.uuid4()
+        self.error = ''
+        # used for files only
+        self._file_total_recv = 0
+        self._file = None
+
+    def __str__(self):
+        return f'{self.__class__.__name__} <{self.jsonheader} {self.data or "None"}>'
+
+    def reset(self):
+        """Re-arm an outgoing missive so that it can be sent once more."""
+        if self._recv_buffer:
+            raise ProtocolError('Resseting of incoming is not possible')
+        self.jsonheader = None
+        self.payload_processed_b = 0
+        self.missive_id = uuid.uuid4()
+        self.error = ''
+        self._complete = False
+        self._recv_buffer = b''
+        self._send_buffer = self._encode()
+        self._jsonheader_len = None
+
+    @property
+    def complete(self):
+        return self._complete
+
+    def consume_bytes(self, bytes):
+        """Feed received bytes into the missive.
+
+        Returns:
+            The bytes that do NOT belong to this missive (the beginning of
+            the next one), or ``b''`` when everything was absorbed.
+        """
+        if self.complete:
+            return bytes
+        self._recv_buffer += bytes
+        if self._jsonheader_len is None:
+            self._process_protoheader()
+
+        if self._jsonheader_len is not None:
+            if self.jsonheader is None:
+                self._process_jsonheader()
+
+        if self.jsonheader:
+            self._process_payload(bytes)
+
+        if not self.complete:
+            # Everything received so far is kept in the internal buffer (or
+            # in the file); handing it back would make the caller feed the
+            # same bytes twice.
+            return b''
+        leftover = self._recv_buffer
+        self._recv_buffer = b''
+        return leftover
+
+    def _process_protoheader(self):
+        """Read the 2-byte length prefix; return the number of bytes used."""
+        hdrlen = 2
+        read = 0
+        if len(self._recv_buffer) >= hdrlen:
+            self._jsonheader_len = struct.unpack('>H',
+                                                 self._recv_buffer[:hdrlen])[0]
+            self._recv_buffer = self._recv_buffer[hdrlen:]
+            read = hdrlen
+        return read
+
+    def _process_jsonheader(self):
+        """Decode the json header; return the number of bytes used.
+
+        Raises:
+            ProtocolError: a required header field is missing.
+        """
+        hdrlen = self._jsonheader_len
+        read = 0
+        if len(self._recv_buffer) >= hdrlen:
+            self.jsonheader = self._json_decode(self._recv_buffer[:hdrlen],
+                                                'utf-8')
+            self._recv_buffer = self._recv_buffer[hdrlen:]
+            for reqhdr in (
+                    'byteorder',
+                    'content-length',
+                    'content-type',
+                    'content-encoding',
+            ):
+                if reqhdr not in self.jsonheader:
+                    raise ProtocolError(f'Missing required header "{reqhdr}"')
+            read = hdrlen
+        return read
+
+    def _process_payload(self, bytes):
+        if self.jsonheader['content-type'] == 'file':
+            return self._process_file(bytes)
+        else:
+            return self._process_text(bytes)
+
+    def _process_text(self, bytes):
+        """Decode the payload once it has been buffered completely."""
+        content_len = self.jsonheader['content-length']
+        if len(self._recv_buffer) < content_len:
+            self.payload_processed_b = len(self._recv_buffer)
+            return len(bytes)
+
+        content = self._recv_buffer[:content_len]
+        # keep what follows the payload: it is the start of the next missive
+        self._recv_buffer = self._recv_buffer[content_len:]
+        if self.jsonheader['content-type'] == 'text/json':
+            encoding = self.jsonheader['content-encoding']
+            self.data = self._json_decode(content, encoding)
+        elif self.jsonheader['content-type'] == 'text/plain':
+            encoding = self.jsonheader['content-encoding']
+            self.data = content.decode(encoding)
+        else:
+            # Binary or unknown content-type
+            self.data = content
+        self._complete = True
+        self.payload_processed_b = content_len
+        return content_len
+
+    def _process_file(self, bytes):
+        """Append buffered payload bytes to a temporary file.
+
+        ``self.data`` becomes the path of that file. The missive is marked
+        complete by the first call made after all ``content-length`` bytes
+        have been written.
+        """
+        read = 0
+        content_len = self.jsonheader['content-length']
+        remaining = content_len - self._file_total_recv
+        if remaining > 0:
+            if not self._file:
+                self.data = os.path.join(tempfile.gettempdir(),
+                                         f'{uuid.uuid4()}')
+                self._file = open(self.data, 'wb')
+            # never write past content-length, the rest is not file data
+            read = self._file.write(self._recv_buffer[:remaining])
+            self._recv_buffer = self._recv_buffer[read:]
+        else:
+            if self._file:
+                self._file.close()
+                self._file = None
+            elif self.data is None:
+                # zero-length file: nothing was ever written, still provide
+                # an (empty) file for the receiver to move
+                self.data = os.path.join(tempfile.gettempdir(),
+                                         f'{uuid.uuid4()}')
+                open(self.data, 'wb').close()
+            self._complete = True
+        self._file_total_recv += read
+        self.payload_processed_b += read
+        return read
+
+    def yield_bytes(self, consumer):
+        """Pass the next chunk to ``consumer``; return the number of bytes it took."""
+        consumed = self._send_from_buffer(consumer)
+        if len(self._send_buffer) == 0:
+            self._complete = True
+        return consumed
+
+    def _send_from_buffer(self, consumer):
+        """Send up to 4096 bytes of the encoded buffer via ``consumer``."""
+        if self._complete:
+            return 0
+        if len(self._send_buffer) == 0:
+            return 0
+        chunk = self._send_buffer[:4096]
+        consumed = consumer(chunk)
+        if consumed is None:
+            raise ProtocolError('Bytes consumer is of an inappropriate type')
+        # drop only what the consumer has really taken
+        self._send_buffer = self._send_buffer[consumed:]
+        return consumed
+
+    def _encode(self):
+        """Return header + payload bytes for ``self.data`` (``b''`` for None).
+
+        Raises:
+            ProtocolError: ``self.data`` is neither str nor dict.
+        """
+        if self.data is None:
+            return b''
+        payload = None
+        jsonheader = {
+            'byteorder': sys.byteorder,
+            'content-type': None,
+            'content-encoding': 'utf-8',
+            'content-length': 0,
+        }
+        if isinstance(self.data, str):
+            jsonheader['content-type'] = 'text/plain'
+            payload = self.data.encode('utf-8')
+        elif isinstance(self.data, dict):
+            jsonheader['content-type'] = 'text/json'
+            payload = self._json_encode(self.data, 'utf-8')
+        else:
+            raise ProtocolError('Supported data types: str, dict')
+
+        jsonheader['content-length'] = len(payload)
+        jsonheader_bytes = self._json_encode(jsonheader, 'utf-8')
+        missive_hdr = struct.pack('>H', len(jsonheader_bytes))
+        encoded = missive_hdr + jsonheader_bytes + payload
+        return encoded
+
+    def _json_encode(self, obj, encoding):
+        return json.dumps(obj, ensure_ascii=False).encode(encoding)
+
+    def _json_decode(self, json_bytes, encoding):
+        """Decode ``json_bytes`` using ``encoding`` and parse them as JSON."""
+        return json.loads(json_bytes.decode(encoding))
+
+
+class Parcel(Missive):
+    """Use it for sending a file."""
+
+    def __init__(self, path):
+        self.path = path
+        self.filename = ''
+        self.size = 0
+        # Missive.__init__ calls _encode(), which fills filename and size
+        super().__init__()
+
+    def yield_bytes(self, consumer):
+        """Send the next piece of the header or of the file via ``consumer``."""
+        consumed = self._send_from_buffer(consumer)
+        if self._send_buffer:
+            # The header has not been flushed completely yet (short write):
+            # file bytes sent now would end up in the middle of it.
+            return consumed
+        consumed += self._send_file(consumer)
+        return consumed
+
+    def _send_file(self, consumer):
+        """Send the next chunk (up to 4096 bytes) of the file.
+
+        The file is opened lazily and closed at EOF, which also marks the
+        parcel as complete.
+        """
+        if self._complete:
+            return 0
+        if not self._file:
+            self._file = open(self.path, 'rb')
+
+        chunk = self._file.read(4096)
+        if chunk == b'':
+            self._file.close()
+            self._file = None
+            self._complete = True
+            return 0
+
+        try:
+            consumed = consumer(chunk)
+        except BaseException:
+            # Nothing was sent (e.g. BlockingIOError): rewind, otherwise the
+            # chunk would silently be skipped on the next attempt.
+            self._file.seek(-len(chunk), 1)
+            raise
+        if consumed is None:
+            raise ProtocolError(
+                'Bytes consumer is of an inappropriate type')
+        if consumed < len(chunk):
+            # short write: step back to the first byte that was not sent
+            diff = len(chunk) - consumed
+            self._file.seek(-diff, 1)
+        self.payload_processed_b += consumed
+        return consumed
+
+    def _encode(self):
+        """Return the encoded header only; the file content is streamed later."""
+        if self.path and os.path.isfile(self.path):
+            self.filename = os.path.basename(self.path)
+            self.size = os.path.getsize(self.path)
+        jsonheader = {
+            'byteorder': sys.byteorder,
+            'content-type': 'file',
+            'content-encoding': 'raw',
+            'filename': self.filename,
+            'content-length': self.size,
+        }
+
+        jsonheader_bytes = self._json_encode(jsonheader, 'utf-8')
+        parcel_hdr = struct.pack('>H', len(jsonheader_bytes))
+        encoded = parcel_hdr + jsonheader_bytes
+        return encoded
+
+
+class ProtocolError(Exception):
+    """Raised on malformed data or on misuse of the protocol classes."""
+
+    def __init__(self, message):
+        super().__init__(message)
+        self.message = message
+
+    def __str__(self):
+        return f'ProtocolError: {self.message}'
+
+
+class DefaultHandler:
+    """Moves missives over one socket, driven by selector events.
+
+    Outgoing missives are queued with enqueue(); the attached Pattern is
+    notified via onsent / onreceived / onprogress.
+    """
+
+    def __init__(self, selector, sock, addr, tag='handler'):
+        self.selector = selector
+        self.sock = sock
+        self.addr = addr
+        self.tag = tag
+        self._queue = []
+        self._recv = None
+        self._callback = NonePattern()
+        self._canclose = True
+        self._isclosed = False
+
+    def setpattern(self, pattern):
+        self._callback = pattern
+
+    def set_selector_events_mask(self, mode):
+        """Set selector to listen for events: mode is 'r', 'w', or 'rw'."""
+
+        if self.sock is None:
+            return
+        if mode == 'r':
+            events = selectors.EVENT_READ
+        elif mode == 'w':
+            events = selectors.EVENT_WRITE
+        elif mode == 'rw':
+            events = selectors.EVENT_READ | selectors.EVENT_WRITE
+        else:
+            raise ProtocolError(f'mask mode {mode!r} not recognized')
+        self.selector.modify(self.sock, events, data=self)
+
+    def enqueue(self, missive):
+        """Queue a missive for sending and start listening for write events."""
+        self._queue.append(missive)
+        self.set_selector_events_mask('rw')
+
+    def _sendout(self, bytes):
+        if self.sock.proto == socket.IPPROTO_UDP:
+            return self.sock.sendto(bytes, self.sock.getsockname())
+        else:
+            return self.sock.send(bytes)
+
+    def _read(self):
+        """Read one chunk from the socket and feed it to the missive(s)."""
+        try:
+            bytes = self.sock.recv(4096)
+        except BlockingIOError as e:
+            # Resource temporarily unavailable (errno EWOULDBLOCK)
+            # with the next attempt it might be ok - so pass for now
+            print(f'BlockingIOError {e}')
+            return
+
+        # an empty read on a stream socket means the peer has closed
+        eof = len(bytes) == 0 and self.sock.type == socket.SOCK_STREAM
+        self._canclose = len(bytes) == 0
+        while True:
+            if not self._recv:
+                self._recv = Missive()
+            try:
+                bytes = self._recv.consume_bytes(bytes)
+            except Exception as e:
+                self._canclose = True
+                self._recv.error = f'Failed to receive, {e}'
+                failed, self._recv = self._recv, None
+                self._callback.onreceived(self, failed)
+                return
+
+            if not self._recv.complete:
+                # all the bytes are kept inside the missive, wait for more
+                self._callback.onprogress(self, self._recv)
+                break
+            self._canclose = len(bytes) == 0
+            done, self._recv = self._recv, None
+            self._callback.onreceived(self, done)
+            # the pattern may have closed the handler in onreceived()
+            if not bytes or self._isclosed:
+                break
+
+        if eof and not self._isclosed:
+            # The peer is gone, so the socket would stay readable forever
+            # (busy loop). Fail a message that was cut short, then close.
+            pending, self._recv = self._recv, None
+            if pending is not None and pending.jsonheader is not None:
+                self._canclose = True
+                pending.error = ('Failed to receive, connection closed '
+                                 'before the message was complete')
+                self._callback.onreceived(self, pending)
+            if not self._isclosed and not self._queue:
+                self.close()
+
+    def _write(self):
+        """Send the next chunk of the first queued missive."""
+        if not self._queue or self._queue[0] is None:
+            self._queue = []
+            self.set_selector_events_mask('r')
+            return
+        try:
+            self._queue[0].yield_bytes(self._sendout)
+        except BlockingIOError as e:
+            # Resource temporarily unavailable (errno EWOULDBLOCK)
+            # with the next attempt it might be ok - so pass for now
+            print(f'BlockingIOError {e}', file=sys.stderr)
+        except (BrokenPipeError, ConnectionResetError) as e:
+            miss = self._queue.pop(0)
+            miss.error = f'Counterparty unable to receive data, {e}'
+            self._callback.onsent(self, miss)
+            self.close()
+            print(miss.error, file=sys.stderr)
+        else:
+            if self._queue[0].complete:
+                self._callback.onsent(self, self._queue.pop(0))
+            else:
+                self._callback.onprogress(self, self._queue[0])
+
+        if not self._queue and not self._isclosed:
+            # Set selector to listen for read events, we're done writing.
+            self.set_selector_events_mask('r')
+
+    def handle_events(self, mask):
+        if mask & selectors.EVENT_WRITE:
+            self._write()
+        # _write() may have closed the handler (there is no socket then)
+        if mask & selectors.EVENT_READ and not self._isclosed:
+            self._read()
+
+    @property
+    def canclose(self):
+        return len(self._queue) == 0 and self._canclose
+
+    def close(self):
+        """Unregister and close the socket; safe to call more than once."""
+        if not self._isclosed:
+            try:
+                self.selector.unregister(self.sock)
+            except Exception as e:
+                print(
+                    f'Error while unregister the socket: selector.unregister() exception for '
+                    f'{self.addr}: {e!r}', file=sys.stderr)
+            try:
+                self.sock.close()
+            except OSError as e:
+                print(
+                    f'{self.__class__.__name__}. Error: socket.close() exception for {self.addr}: {e!r}',
+                    file=sys.stderr
+                )
+        # Delete reference to socket object for garbage collection
+        self.sock = None
+        # a do-nothing pattern instead of None: a callback that fires after
+        # close() must not raise AttributeError
+        self._callback = NonePattern()
+        self._isclosed = True
+        self._canclose = True
+        self._queue = []
+
+
+class Pattern(ABC):
+    """Implement a custom logic based on this class."""
+
+    @abstractmethod
+    def onprogress(self, handler, missive):
+        pass
+
+    @abstractmethod
+    def onsent(self, handler, missive):
+        pass
+
+    @abstractmethod
+    def onreceived(self, handler, missive):
+        pass
+
+
+class NonePattern(Pattern):
+    """Pattern that ignores every event."""
+
+    def onprogress(self, handler, missive):
+        pass
+
+    def onsent(self, handler, missive):
+        pass
+
+    def onreceived(self, handler, missive):
+        pass
+
+
+class NoReply(Pattern):
+    """Simple client pattern that does NOT expect a reply from server."""
+
+    def onprogress(self, handler, missive):
+        pass
+
+    def onsent(self, handler, missive):
+        if handler.canclose:
+            handler.close()
+
+    def onreceived(self, handler, missive):
+        pass
diff --git a/src/util.py b/src/util.py
new file mode 100644
index 0000000..129082f
--- /dev/null
+++ b/src/util.py
@@ -0,0 +1,176 @@
+import socket
+import selectors
+
+from protocol import DefaultHandler
+from protocol import NoReply
+
+
+def accept_connection(sel, sock, callback):
+    """
+    Accepts an incoming connection using protocol.DefaultHandler.
+
+    The accepted connection is made non-blocking and registered with the
+    selector for read events.
+
+    Args:
+        sel (selector): A selector from selectors module.
+        sock (socket): A socket from socket module.
+        callback (Pattern): Implementation of protocol.Pattern.
+
+    Returns:
+        None.
+    """
+    conn, addr = sock.accept()
+    # UNIX socket connections get an empty tag instead of the peer address
+    tag = '' if conn.family == socket.AF_UNIX else addr
+    conn.setblocking(False)
+    handler = DefaultHandler(sel, conn, addr, tag)
+    handler.setpattern(callback)
+    sel.register(conn, selectors.EVENT_READ, data=handler)
+
+
+def register_udp_handler(sel, ip, port):
+    """
+    UPD listening with a predefined protocol.DefaultHandler.
+
+    Args:
+        sel (selector): A selector from selectors module.
+ ip (str): Broadcasting lan IP address. + port (int): Port to listen to. + + Returns: + DefaultHandler: The handler registered in selector. + """ + udp_sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, + socket.IPPROTO_UDP) + udp_sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1) + udp_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + udp_sock.bind((ip, port)) # no need to listen for UDP + print(f'Listening on {(ip, port)}') + udp_sock.setblocking(False) + udp_handler = DefaultHandler(sel, udp_sock, ip, 'udp') + # set data=udp_handler right here as accept_connection() + # cannot be applied in case of UDP + sel.register(udp_sock, selectors.EVENT_READ, data=udp_handler) + return udp_handler + + +def register_unix_handler(sel, path): + """ + UNIX listening. A handler is to be set on accepting a new connection. + + Args: + sel (selector): A selector from selectors module. + path (str): A path to the UNIX socket. + + Returns: + None. + """ + unix_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + unix_sock.bind(path) + unix_sock.listen(1) + unix_sock.setblocking(False) + sel.register(unix_sock, selectors.EVENT_READ, data=None) + + +def register_tcp_handler(sel, ip, port): + """ + TCP listening. A handler is to be set on accepting a new connection. + + Args: + sel (selector): A selector from selectors module. + ip (str): This machine's lan IP address. + port (int): Port to listen to. + + Returns: + None. + """ + tcp_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + tcp_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + tcp_sock.bind((ip, port)) + tcp_sock.listen() + print(f'Listening on {(ip, port)}') + tcp_sock.setblocking(False) + sel.register(tcp_sock, selectors.EVENT_READ, data=None) + + +def send_via_tcp(sel, ip, port, pattern, *missives): + """ + Initialize outbound TCP connection with protocol.DefaultHandler. + + Args: + sel (selector): A selector from selectors module. + ip (str): This machine's lan IP address. 
+ port (int): Port to listen to. + pattern (Pattern): Implementation of protocol.Pattern. + missives (Missive): A sequence of messages to send. + + Returns: + None. + """ + if not missives: + return + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.setblocking(True) + sock.connect((ip, port)) + events = selectors.EVENT_WRITE + handler = DefaultHandler(sel, sock, ip, 'tcp-out') + handler.setpattern(pattern) + sel.register(sock, events, data=handler) + for missive in missives: + handler.enqueue(missive) + + +def broadcast_via_udp(sel, ip, port, missive): + """ + Initialize UDP broadcast with protocol.DefaultHandler. + + Args: + sel (selector): A selector from selectors module. + ip (str): Broacasting lan IP address. + port (int): Port for broadcasting. + missive (Missive): A message to send out. + + Returns: + None. + """ + if not missive: + return + sock = socket.socket(socket.AF_INET, + socket.SOCK_DGRAM, + socket.IPPROTO_UDP) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.setblocking(True) + sock.bind((ip, port)) + events = selectors.EVENT_WRITE + handler = DefaultHandler(sel, sock, ip, 'udp-out') + handler.setpattern(NoReply()) + sel.register(sock, events, data=handler) + handler.enqueue(missive) + + +def send_via_unix(sel, path, pattern, *missives): + """ + Initialize UNIX socket connection with protocol.DefaultHandler. + + Args: + sel (selector): A selector from selectors module. + path (str): A path to a UNIX socket file. + pattern (Pattern): Implementation of protocol.Pattern. + missives (Missive): A sequence of messages to send. + + Returns: + None. 
+ """ + if not missives: + return + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + sock.setblocking(True) + sock.connect(path) + events = selectors.EVENT_WRITE + handler = DefaultHandler(sel, sock, + path, 'unix-out') + handler.setpattern(pattern) + sel.register(sock, events, data=handler) + for missive in missives: + handler.enqueue(missive)