source: rtems-tools/tester/rt/tftpserver.py @ 7a29b47

Last change on this file since 7a29b47 was 7a29b47, checked in by Chris Johns <chrisj@…>, on 09/07/20 at 00:07:11

tester: Fix TFTP server Python2 issues.

  • Add a --show-backtrace option to make it easier for users to get an exception backtrace if something goes wrong.
  • Fix the --packet-trace option so it actually decodes the packets
  • Property mode set to 100644
File size: 25.9 KB
Line 
1# SPDX-License-Identifier: BSD-2-Clause
2'''The TFTP Server handles a read only TFTP session.'''
3
4# Copyright (C) 2020 Chris Johns (chrisj@rtems.org)
5#
6# Redistribution and use in source and binary forms, with or without
7# modification, are permitted provided that the following conditions
8# are met:
9# 1. Redistributions of source code must retain the above copyright
10#    notice, this list of conditions and the following disclaimer.
11# 2. Redistributions in binary form must reproduce the above copyright
12#    notice, this list of conditions and the following disclaimer in the
13#    documentation and/or other materials provided with the distribution.
14#
15# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
16# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
18# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
19# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
20# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
21# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
22# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
23# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
24# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
25# POSSIBILITY OF SUCH DAMAGE.
26
27from __future__ import print_function
28
29import argparse
30import os
31import socket
32import sys
33import time
34import threading
35
36try:
37    import socketserver
38except ImportError:
39    import SocketServer as socketserver
40
41from rtemstoolkit import error
42from rtemstoolkit import log
43from rtemstoolkit import version
44
45
46class tftp_session(object):
47    '''Handle the TFTP session packets initiated on the TFTP port (69).
48    '''
49    # pylint: disable=useless-object-inheritance
50    # pylint: disable=too-many-instance-attributes
51
52    opcodes = ['nul', 'RRQ', 'WRQ', 'DATA', 'ACK', 'ERROR', 'OACK']
53
54    OP_RRQ = 1
55    OP_WRQ = 2
56    OP_DATA = 3
57    OP_ACK = 4
58    OP_ERROR = 5
59    OP_OACK = 6
60
61    E_NOT_DEFINED = 0
62    E_FILE_NOT_FOUND = 1
63    E_ACCESS_VIOLATION = 2
64    E_DISK_FULL = 3
65    E_ILLEGAL_TFTP_OP = 4
66    E_UKNOWN_TID = 5
67    E_FILE_ALREADY_EXISTS = 6
68    E_NO_SUCH_USER = 7
69    E_NO_ERROR = 10
70
71    def __init__(self, host, port, base, forced_file, reader=None):
72        # pylint: disable=too-many-arguments
73        self.host = host
74        self.port = port
75        self.base = base
76        self.forced_file = forced_file
77        if reader is None:
78            self.data_reader = self._file_reader
79        else:
80            self.data_reader = reader
81        self.filein = None
82        self.resends_limit = 5
83        # These are here to shut pylint up
84        self.block = 0
85        self.last_data = None
86        self.block_size = 512
87        self.timeout = 0
88        self.resends = 0
89        self.finished = False
90        self.filename = None
91        self._reinit()
92
93    def _reinit(self):
94        '''Reinitialise all the class variables used by the protocol.'''
95        if self.filein is not None:
96            self.filein.close()
97            self.filein = None
98        self.block = 0
99        self.last_data = None
100        self.block_size = 512
101        self.timeout = 0
102        self.resends = 0
103        self.finished = False
104        self.filename = None
105
106    def _file_reader(self, command, **kwargs):
107        '''The default file reader if the user does not provide one.
108
109        The call returns a two element tuple where the first element
110        is an error code, and the second element is data if the error
111        code is 0 else it is an error message.
112        '''
113        # pylint: disable=too-many-return-statements
114        if command == 'open':
115            if 'filename' not in kwargs:
116                raise error.general('tftp-reader: invalid open: no filename')
117            filename = kwargs['filename']
118            try:
119                self.filein = open(filename, 'rb')
120                filesize = os.stat(filename).st_size
121            except FileNotFoundError:
122                return self.E_FILE_NOT_FOUND, 'file not found (%s)' % (
123                    filename)
124            except PermissionError:
125                return self.E_ACCESS_VIOLATION, 'access violation'
126            except IOError as ioe:
127                return self.E_NOT_DEFINED, str(ioe)
128            return self.E_NO_ERROR, str(filesize)
129        if command == 'read':
130            if self.filein is None:
131                raise error.general('tftp-reader: read when not open')
132            if 'blksize' not in kwargs:
133                raise error.general('tftp-reader: invalid read: no blksize')
134            # pylint: disable=bare-except
135            try:
136                return self.E_NO_ERROR, self.filein.read(kwargs['blksize'])
137            except IOError as ioe:
138                return self.E_NOT_DEFINED, str(ioe)
139            except:
140                return self.E_NOT_DEFINED, 'unknown error'
141        if command == 'close':
142            if self.filein is not None:
143                self.filein.close()
144                self.filein = None
145            return self.E_NO_ERROR, "closed"
146        return self.E_NOT_DEFINED, 'invalid reader state'
147
148    @staticmethod
149    def _pack_bytes(data=None):
150        bdata = bytearray()
151        if data is not None:
152            if not isinstance(data, list):
153                data = [data]
154            for item in data:
155                if isinstance(item, int):
156                    bdata.append(item >> 8)
157                    bdata.append(item & 0xff)
158                elif isinstance(item, str):
159                    bdata.extend(item.encode())
160                    bdata.append(0)
161                else:
162                    bdata.extend(item)
163        return bdata
164
165    def _response(self, opcode, data):
166        code = self.opcodes.index(opcode)
167        if code == 0 or code >= len(self.opcodes):
168            raise error.general('invalid opcode: ' + opcode)
169        bdata = self._pack_bytes([code, data])
170        #print(''.join(format(x, '02x') for x in bdata))
171        return bdata
172
173    def _error_response(self, code, message):
174        if log.tracing:
175            log.trace('tftp: error: %s:%d: %d: %s' %
176                      (self.host, self.port, code, message))
177        self.finished = True
178        return self._response('ERROR', self._pack_bytes([code, message, 0]))
179
180    def _data_response(self, block, data):
181        if len(data) < self.block_size:
182            self.finished = True
183        return self._response('DATA', self._pack_bytes([block, data]))
184
185    def _oack_response(self, data):
186        self.resends += 1
187        if self.resends >= self.resends_limit:
188            return self._error_response(self.E_NOT_DEFINED,
189                                        'resend limit reached')
190        return self._response('OACK', self._pack_bytes(data))
191
192    def _next_block(self, block):
193        # has the current block been acknowledged?
194        if block == self.block:
195            self.resends = 0
196            self.block += 1
197            err, data = self.data_reader('read', blksize=self.block_size)
198            data = bytearray(data)
199            if err != self.E_NO_ERROR:
200                return self._error_response(err, data)
201            # close if the length of data is less than the block size
202            if len(data) < self.block_size:
203                self.data_reader('close')
204            self.last_data = data
205        else:
206            self.resends += 1
207            if self.resends >= self.resends_limit:
208                return self._error_response(self.E_NOT_DEFINED,
209                                            'resend limit reached')
210            data = self.last_data
211        return self._data_response(self.block, data)
212
213    def _read_req(self, data):
214        # if the last block is not 0 something has gone wrong and
215        # TID match. Restart the session. It could be the client
216        # is a simple implementation that does not move the send
217        # port on each retry.
218        if self.block != 0:
219            self.data_reader('close')
220            self._reinit()
221        # Get the filename, mode and options
222        self.filename = self.get_option('filename', data)
223        if self.filename is None:
224            return self._error_response(self.E_NOT_DEFINED,
225                                        'filename not found in request')
226        if self.forced_file is not None:
227            self.filename = self.forced_file
228        # open the reader
229        err, message = self.data_reader('open', filename=self.filename)
230        if err != self.E_NO_ERROR:
231            return self._error_response(err, message)
232        # the no error on open message is the file size
233        try:
234            tsize = int(message)
235        except ValueError:
236            tsize = 0
237        mode = self.get_option('mode', data)
238        if mode is None:
239            return self._error_response(self.E_NOT_DEFINED,
240                                        'mode not found in request')
241        oack_data = self._pack_bytes()
242        value = self.get_option('timeout', data)
243        if value is not None:
244            oack_data += self._pack_bytes(['timeout', value])
245            self.timeout = int(value)
246        value = self.get_option('blksize', data)
247        if value is not None:
248            oack_data += self._pack_bytes(['blksize', value])
249            self.block_size = int(value)
250        else:
251            self.block_size = 512
252        value = self.get_option('tsize', data)
253        if value is not None and tsize > 0:
254            oack_data += self._pack_bytes(['tsize', str(tsize)])
255        # Send the options ack
256        return self._oack_response(oack_data)
257
258    def _write_req(self):
259        # WRQ is not supported
260        return self._error_response(self.E_ILLEGAL_TFTP_OP,
261                                    "writes not supported")
262
263    def _op_ack(self, data):
264        # send the next block of data
265        block = (data[2] << 8) | data[3]
266        return self._next_block(block)
267
268    def process(self, host, port, data):
269        '''Process the incoming client data sending a response. If the session
270        has finished return None.
271        '''
272        if host != self.host and port != self.port:
273            return self._error_response(self.E_UKNOWN_TID,
274                                        'unkown transfer ID')
275        if self.finished:
276            return None
277        opcode = (data[0] << 8) | data[1]
278        if opcode == self.OP_RRQ:
279            return self._read_req(data)
280        if opcode in [self.OP_WRQ, self.OP_DATA]:
281            return self._write_req()
282        if opcode == self.OP_ACK:
283            return self._op_ack(data)
284        return self._error_response(self.E_ILLEGAL_TFTP_OP,
285                                    "unknown or unsupported opcode")
286
287    def decode(self, host, port, data):
288        '''Decode the packet for diagnostic purposes.
289        '''
290        # pylint: disable=too-many-branches
291        out = ''
292        dlen = len(data)
293        if dlen > 2:
294            opcode = (data[0] << 8) | data[1]
295            if 0 < opcode < len(self.opcodes):
296                if opcode in [self.OP_RRQ, self.OP_WRQ]:
297                    out += '  ' + self.opcodes[opcode] + ', '
298                    i = 2
299                    while data[i] != 0:
300                        out += chr(data[i])
301                        i += 1
302                    while i < dlen - 1:
303                        out += ', '
304                        i += 1
305                        while data[i] != 0:
306                            out += chr(data[i])
307                            i += 1
308                elif opcode == self.OP_DATA:
309                    block = (data[2] << 8) | data[3]
310                    out += '  ' + self.opcodes[opcode] + ', '
311                    out += '#' + str(block) + ', '
312                    if dlen > 4:
313                        out += '%02x%02x..%02x%02x' % (data[4], data[5],
314                                                       data[-2], data[-1])
315                    else:
316                        out += '%02x%02x%02x%02x' % (data[4], data[5], data[6],
317                                                     data[6])
318                    out += ',' + str(dlen - 4)
319                elif opcode == self.OP_ACK:
320                    block = (data[2] << 8) | data[3]
321                    out += '  ' + self.opcodes[opcode] + ' ' + str(block)
322                elif opcode == self.OP_ERROR:
323                    out += 'E ' + self.opcodes[opcode] + ', '
324                    out += str((data[2] << 8) | (data[3]))
325                    out += ': ' + str(data[4:].decode())
326                    i = 2
327                    while data[i] != 0:
328                        out += chr(data[i])
329                        i += 1
330                elif opcode == self.OP_OACK:
331                    out += '  ' + self.opcodes[opcode]
332                    i = 1
333                    while i < dlen - 1:
334                        out += ', '
335                        i += 1
336                        while data[i] != 0:
337                            out += chr(data[i])
338                            i += 1
339            else:
340                out += 'E INV(%d)' % (opcode)
341        else:
342            out += 'E INVALID LENGTH'
343        return out[:2] + '[%s:%d] (%d) ' % (host, port, len(data)) + out[2:]
344
345    @staticmethod
346    def get_option(option, data):
347        '''Get the option from the TFTP packet.'''
348        dlen = len(data) - 1
349        opcode = (data[0] << 8) | data[1]
350        next_option = False
351        if opcode in [1, 2]:
352            count = 0
353            i = 2
354            while i < dlen:
355                value = ''
356                while data[i] != 0:
357                    value += chr(data[i])
358                    i += 1
359                i += 1
360                if option == 'filename' and count == 0:
361                    return value
362                if option == 'mode' and count == 1:
363                    return value
364                if value == option and (count % 1) == 0:
365                    next_option = True
366                elif next_option:
367                    return value
368                count += 1
369        return None
370
371    def get_timeout(self, default_timeout, timeout_guard):
372        '''Get the timeout. The timeout can be an option.'''
373        if self.timeout == 0:
374            return self.timeout + timeout_guard
375        return default_timeout
376
377    def get_block_size(self):
378        '''Get the block size. The block size can be an option.'''
379        return self.block_size
380
381
382class udp_handler(socketserver.BaseRequestHandler):
383    '''TFTP UDP handler for a TFTP session.'''
384    def _notice(self, text):
385        if self.server.tftp.notices:
386            log.notice(text)
387        else:
388            log.trace(text)
389
390    def handle_session(self, index):
391        '''Handle the TFTP session data.'''
392        # pylint: disable=too-many-locals
393        # pylint: disable=broad-except
394        # pylint: disable=too-many-branches
395        # pylint: disable=too-many-statements
396        client_ip = self.client_address[0]
397        client_port = self.client_address[1]
398        client = '%s:%i' % (client_ip, client_port)
399        self._notice('] tftp: %d: start: %s' % (index, client))
400        try:
401            session = tftp_session(client_ip, client_port,
402                                   self.server.tftp.base,
403                                   self.server.tftp.forced_file,
404                                   self.server.tftp.reader)
405            data = bytearray(self.request[0])
406            response = session.process(client_ip, client_port, data)
407            if response is not None:
408                if log.tracing and self.server.tftp.packet_trace:
409                    log.trace(' > ' +
410                              session.decode(client_ip, client_port, data))
411                timeout = session.get_timeout(self.server.tftp.timeout, 1)
412                sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
413                sock.bind(('', 0))
414                sock.settimeout(timeout)
415                while response is not None:
416                    if log.tracing and self.server.tftp.packet_trace:
417                        log.trace(
418                            ' < ' +
419                            session.decode(client_ip, client_port, response))
420                    sock.sendto(response, (client_ip, client_port))
421                    if session.finished:
422                        break
423                    try:
424                        data, address = sock.recvfrom(2 + 2 +
425                                                      session.get_block_size())
426                        data = bytearray(data)
427                        if log.tracing and self.server.tftp.packet_trace:
428                            log.trace(
429                                ' > ' +
430                                session.decode(address[0], address[1], data))
431                    except socket.error as serr:
432                        if log.tracing:
433                            log.trace('] tftp: %d: receive: %s: error: %s' \
434                                      % (index, client, serr))
435                        return
436                    except socket.gaierror as serr:
437                        if log.tracing:
438                            log.trace('] tftp: %d: receive: %s: error: %s' \
439                                      % (index, client, serr))
440                        return
441                    response = session.process(address[0], address[1], data)
442        except error.general as gerr:
443            self._notice('] tftp: %dd: error: %s' % (index, gerr))
444        except error.internal as ierr:
445            self._notice('] tftp: %d: error: %s' % (index, ierr))
446        except error.exit:
447            pass
448        except KeyboardInterrupt:
449            pass
450        except Exception as exp:
451            if self.server.tftp.exception_is_raise:
452                raise
453            self._notice('] tftp: %d: error: %s: %s' % (index, type(exp), exp))
454        self._notice('] tftp: %d: end: %s' % (index, client))
455
456    def handle(self):
457        '''The UDP server handle method.'''
458        if self.server.tftp.sessions is None \
459           or self.server.tftp.session < self.server.tftp.sessions:
460            self.handle_session(self.server.tftp.next_session())
461
462
463class udp_server(socketserver.ThreadingMixIn, socketserver.UDPServer):
464    '''UDP server. Default behaviour.'''
465
466
467class tftp_server(object):
468    '''TFTP server runs a UDP server to handle TFTP sessions.'''
469
470    # pylint: disable=useless-object-inheritance
471    # pylint: disable=too-many-instance-attributes
472
473    def __init__(self,
474                 host,
475                 port,
476                 timeout=10,
477                 base=None,
478                 forced_file=None,
479                 sessions=None,
480                 reader=None):
481        # pylint: disable=too-many-arguments
482        self.lock = threading.Lock()
483        self.notices = False
484        self.packet_trace = False
485        self.exception_is_raise = False
486        self.timeout = timeout
487        self.host = host
488        self.port = port
489        self.server = None
490        self.server_thread = None
491        if base is None:
492            base = os.getcwd()
493        self.base = base
494        self.forced_file = forced_file
495        if sessions is not None and not isinstance(sessions, int):
496            raise error.general('tftp session count is not a number')
497        self.sessions = sessions
498        self.session = 0
499        self.reader = reader
500
501    def __del__(self):
502        self.stop()
503
504    def _lock(self):
505        self.lock.acquire()
506
507    def _unlock(self):
508        self.lock.release()
509
510    def start(self):
511        '''Start the TFTP server. Returns once started.'''
512        # pylint: disable=attribute-defined-outside-init
513        if log.tracing:
514            log.trace('] tftp: server: %s:%i' % (self.host, self.port))
515        if self.host == 'all':
516            host = ''
517        else:
518            host = self.host
519        try:
520            self.server = udp_server((host, self.port), udp_handler)
521        except Exception as exp:
522            raise error.general('tftp server create: %s' % (exp))
523        # We cannot set tftp in __init__ because the object is created
524        # in a separate package.
525        self.server.tftp = self
526        self.server_thread = threading.Thread(target=self.server.serve_forever)
527        self.server_thread.daemon = True
528        self.server_thread.start()
529
530    def stop(self):
531        '''Stop the TFTP server and close the server port.'''
532        self._lock()
533        try:
534            if self.server is not None:
535                self.server.shutdown()
536                self.server.server_close()
537                self.server = None
538        finally:
539            self._unlock()
540
541    def run(self):
542        '''Run the TFTP server for the specified number of sessions.'''
543        running = True
544        while running:
545            period = 1
546            self._lock()
547            if self.server is None:
548                running = False
549                period = 0
550            elif self.sessions is not None:
551                if self.sessions == 0:
552                    running = False
553                    period = 0
554                else:
555                    period = 0.25
556            self._unlock()
557            if period > 0:
558                time.sleep(period)
559        self.stop()
560
561    def get_session(self):
562        '''Return the session count.'''
563        count = 0
564        self._lock()
565        try:
566            count = self.session
567        finally:
568            self._unlock()
569        return count
570
571    def next_session(self):
572        '''Return the next session number.'''
573        count = 0
574        self._lock()
575        try:
576            self.session += 1
577            count = self.session
578        finally:
579            self._unlock()
580        return count
581
582    def enable_notices(self):
583        '''Call to enable notices. The server is quiet without this call.'''
584        self._lock()
585        self.notices = True
586        self._unlock()
587
588    def trace_packets(self):
589        '''Call to enable packet tracing as a diagnostic.'''
590        self._lock()
591        self.packet_trace = True
592        self._unlock()
593
594    def except_is_raise(self):
595        '''If True a standard exception will generate a backtrace.'''
596        self.exception_is_raise = True
597
598
599def load_log(logfile):
600    '''Set the log file.'''
601    if logfile is None:
602        log.default = log.log(streams=['stdout'])
603    else:
604        log.default = log.log(streams=[logfile])
605
606
607def run(args=sys.argv, command_path=None):
608    '''Run a TFTP server session.'''
609    # pylint: disable=dangerous-default-value
610    # pylint: disable=unused-argument
611    # pylint: disable=too-many-branches
612    # pylint: disable=too-many-statements
613    ecode = 0
614    notice = None
615    server = None
616    # pylint: disable=bare-except
617    try:
618        description = 'A TFTP Server that supports a read only TFTP session.'
619
620        nice_cwd = os.path.relpath(os.getcwd())
621        if len(nice_cwd) > len(os.path.abspath(nice_cwd)):
622            nice_cwd = os.path.abspath(nice_cwd)
623
624        argsp = argparse.ArgumentParser(prog='rtems-tftp-server',
625                                        description=description)
626        argsp.add_argument('-l',
627                           '--log',
628                           help='log file.',
629                           type=str,
630                           default=None)
631        argsp.add_argument('-v',
632                           '--trace',
633                           help='enable trace logging for debugging.',
634                           action='store_true',
635                           default=False)
636        argsp.add_argument('--trace-packets',
637                           help='enable trace logging of packets.',
638                           action='store_true',
639                           default=False)
640        argsp.add_argument('--show-backtrace',
641                           help='show the exception backtrace.',
642                           action='store_true',
643                           default=False)
644        argsp.add_argument(
645            '-B',
646            '--bind',
647            help='address to bind the server too (default: %(default)s).',
648            type=str,
649            default='all')
650        argsp.add_argument(
651            '-P',
652            '--port',
653            help='port to bind the server too (default: %(default)s).',
654            type=int,
655            default='69')
656        argsp.add_argument('-t', '--timeout',
657                           help = 'timeout in seconds, client can override ' \
658                           '(default: %(default)s).',
659                           type = int, default = '10')
660        argsp.add_argument(
661            '-b',
662            '--base',
663            help='base path, not checked (default: %(default)s).',
664            type=str,
665            default=nice_cwd)
666        argsp.add_argument(
667            '-F',
668            '--force-file',
669            help='force the file to be downloaded overriding the client.',
670            type=str,
671            default=None)
672        argsp.add_argument('-s', '--sessions',
673                           help = 'number of TFTP sessions to run before exiting ' \
674                           '(default: forever.',
675                           type = int, default = None)
676
677        argopts = argsp.parse_args(args[1:])
678
679        load_log(argopts.log)
680        log.notice('RTEMS Tools - TFTP Server, %s' % (version.string()))
681        log.output(log.info(args))
682        log.tracing = argopts.trace
683
684        server = tftp_server(argopts.bind, argopts.port, argopts.timeout,
685                             argopts.base, argopts.force_file,
686                             argopts.sessions)
687        server.enable_notices()
688        if argopts.trace_packets:
689            server.trace_packets()
690        if argopts.show_backtrace:
691            server.except_is_raise()
692
693        try:
694            server.start()
695            server.run()
696        finally:
697            server.stop()
698
699    except error.general as gerr:
700        notice = str(gerr)
701        ecode = 1
702    except error.internal as ierr:
703        notice = str(ierr)
704        ecode = 1
705    except error.exit:
706        pass
707    except KeyboardInterrupt:
708        notice = 'abort: user terminated'
709        ecode = 1
710    except SystemExit:
711        pass
712    except:
713        notice = 'abort: unknown error'
714        ecode = 1
715    if server is not None:
716        del server
717    if notice is not None:
718        log.stderr(notice)
719    sys.exit(ecode)
720
721
722if __name__ == "__main__":
723    run()
Note: See TracBrowser for help on using the repository browser.