"""This module implements all contexts for state handling during uploads and
downloads, the main interface to which is the TftpContext base class.

The concept is simple. Each context object represents a single upload or
download, and the state object in the context object represents the current
state of that transfer. The state object has a handle() method that expects
the next packet in the transfer, and returns a state object until the transfer
is complete, at which point it returns None. That is, unless there is a fatal
error, in which case a TftpException is raised instead."""
---|
10 | |
---|
11 | from __future__ import absolute_import, division, print_function, unicode_literals |
---|
12 | from .TftpShared import * |
---|
13 | from .TftpPacketTypes import * |
---|
14 | from .TftpPacketFactory import TftpPacketFactory |
---|
15 | from .TftpStates import * |
---|
16 | import socket, time, sys |
---|
17 | |
---|
18 | ############################################################################### |
---|
19 | # Utility classes |
---|
20 | ############################################################################### |
---|
21 | |
---|
class TftpMetrics(object):
    """A class representing metrics of the transfer.

    The owning context fills in the raw counters (``bytes``, ``resent_bytes``,
    ``start_time``/``end_time``) and records duplicates via add_dup(). It then
    calls compute() to derive ``duration``, ``bps``, ``kbps`` and
    ``dupcount``."""
    def __init__(self):
        # Bytes transferred
        self.bytes = 0
        # Bytes re-sent
        self.resent_bytes = 0
        # Duplicate packets received: str(pkt) -> number of times seen.
        self.dups = {}
        # Total of all counts in self.dups; derived by compute().
        self.dupcount = 0
        # Times
        self.start_time = 0
        self.end_time = 0
        self.duration = 0
        # Rates: bits per second, and that value divided by 1024.
        self.bps = 0
        self.kbps = 0
        # Generic errors
        self.errors = 0

    def compute(self):
        """Derive duration, bps, kbps and dupcount from the raw counters.

        Every derived field is recomputed from scratch, so calling this more
        than once gives the same result. That matters because a context's
        end() can run twice: once explicitly and once from __del__."""
        # Compute transfer time
        self.duration = self.end_time - self.start_time
        if self.duration == 0:
            # Avoid a division by zero for transfers that finish within the
            # clock's resolution; treat them as taking one second.
            self.duration = 1
        log.debug("TftpMetrics.compute: duration is %s", self.duration)
        self.bps = (self.bytes * 8.0) / self.duration
        self.kbps = self.bps / 1024.0
        log.debug("TftpMetrics.compute: kbps is %s", self.kbps)
        # Recompute rather than "+=" so repeated calls don't double-count.
        self.dupcount = sum(self.dups.values())

    def add_dup(self, pkt):
        """Record one more duplicate of pkt.

        Duplicates are keyed by str(pkt). The per-packet count is checked
        with tftpassert against MAX_DUPS."""
        log.debug("Recording a dup of %s", pkt)
        s = str(pkt)
        self.dups[s] = self.dups.get(s, 0) + 1
        tftpassert(self.dups[s] < MAX_DUPS, "Max duplicates reached")
---|
63 | |
---|
64 | ############################################################################### |
---|
65 | # Context classes |
---|
66 | ############################################################################### |
---|
67 | |
---|
class TftpContext(object):
    """The base class of the contexts.

    A context owns the UDP socket, the optional file object and the current
    state object of a single transfer. Subclasses implement start() and
    extend end()."""

    def __init__(self, host, port, timeout, localip = ""):
        """Constructor for the base context, setting shared instance
        variables.

        host    -- peer hostname or address; resolved once into self.address.
        port    -- peer port.
        timeout -- socket timeout in seconds, also used by checkTimeout().
        localip -- if non-empty, bind the socket to this local address
                   (with an OS-chosen port)."""
        # Plain attributes are set first. If binding the socket or resolving
        # the host below raises, the half-built object still has everything
        # end() (reached from __del__) relies on.
        self.file_to_transfer = None
        self.fileobj = None
        self.options = None
        self.packethook = None
        self.timeout = timeout
        self.state = None
        self.next_block = 0
        self.factory = TftpPacketFactory()
        # The port associated with the TID
        self.tidport = None
        # Metrics
        self.metrics = TftpMetrics()
        # Flag when the transfer is pending completion.
        self.pending_complete = False
        # Time when this context last received any traffic.
        # FIXME: does this belong in metrics?
        self.last_update = 0
        # The last packet we sent, if applicable, to make resending easy.
        self.last_pkt = None
        # Count the number of retry attempts.
        self.retry_count = 0

        self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        if localip:
            self.sock.bind((localip, 0))
        self.sock.settimeout(timeout)
        self.port = port
        # Note, setting the host will also set self.address, as it's a
        # property; this performs a name lookup and may raise.
        self.host = host

    def getBlocksize(self):
        """Fetch the current blocksize for this session.

        Returns the 'blksize' option as an int. Falls back to 512 when no
        options have been set yet (options is None) or blksize is absent."""
        if not self.options:
            return 512
        return int(self.options.get('blksize', 512))

    def __del__(self):
        """Simple destructor to try to call housekeeping in the end method if
        not called explicitly. Leaking file descriptors is not a good
        thing."""
        # If __init__ failed before the socket existed there is nothing to
        # clean up, and end() would fail on the missing attribute.
        if getattr(self, "sock", None) is not None:
            self.end()

    def checkTimeout(self, now):
        """Compare current time with last_update time, and raise an exception
        if we're over the timeout time.

        now -- current time in seconds, on the same clock as last_update
               (which is set from time.time()).
        Raises TftpTimeout when more than self.timeout seconds have passed."""
        log.debug("checking for timeout on session %s", self)
        if now - self.last_update > self.timeout:
            raise TftpTimeout("Timeout waiting for traffic")

    def start(self):
        raise NotImplementedError("Abstract method")

    def end(self):
        """Perform session cleanup: close the socket and, if it is still
        open, self.fileobj. Since the end method should always be called
        explicitly by the calling code, this works better than the
        destructor."""
        log.debug("in TftpContext.end")
        self.sock.close()
        if self.fileobj is not None and not self.fileobj.closed:
            log.debug("self.fileobj is open - closing")
            self.fileobj.close()

    def gethost(self):
        "Simple getter method for use in a property."
        return self.__host

    def sethost(self, host):
        """Setter method that also sets the address property as a result
        of the host that is set."""
        self.__host = host
        self.address = socket.gethostbyname(host)

    host = property(gethost, sethost)

    def setNextBlock(self, block):
        """Set the next expected block number.

        Values that no longer fit in 16 bits wrap back to 0."""
        if block >= 2 ** 16:
            log.debug("Block number rollover to 0 again")
            block = 0
        self.__eblock = block

    def getNextBlock(self):
        """Return the next expected block number."""
        return self.__eblock

    next_block = property(getNextBlock, setNextBlock)

    def cycle(self):
        """Here we wait for a response from the server after sending it
        something, and dispatch appropriate action to that response.

        Datagrams from an address other than self.address, or from a port
        other than the transfer ID port (once one is known), are logged and
        discarded. They are not parsed, do not refresh last_update, do not
        reset retry_count and never reach the state machine.

        Raises TftpTimeout if nothing arrives within the socket timeout."""
        try:
            (buffer, (raddress, rport)) = self.sock.recvfrom(MAX_BLKSIZE)
        except socket.timeout:
            log.warning("Timeout waiting for traffic, retrying...")
            raise TftpTimeout("Timed-out waiting for traffic")

        # Ok, we've received a packet. Log it.
        log.debug("Received %d bytes from %s:%s",
                  len(buffer), raddress, rport)

        # Check for known "connection". Stray traffic must not disturb the
        # transfer (RFC 1350), so bail out before touching any state.
        if raddress != self.address:
            log.warning("Received traffic from %s, expected host %s. Discarding",
                        raddress, self.host)
            return

        if self.tidport and self.tidport != rport:
            log.warning("Received traffic from %s:%s but we're "
                        "connected to %s:%s. Discarding.",
                        raddress, rport, self.host, self.tidport)
            return

        # And update our last updated time.
        self.last_update = time.time()

        # Decode it.
        recvpkt = self.factory.parse(buffer)

        # If there is a packethook defined, call it. We unconditionally
        # pass all packets, it's up to the client to screen out different
        # kinds of packets. This way, the client is privy to things like
        # negotiated options.
        if self.packethook:
            self.packethook(recvpkt)

        # And handle it, possibly changing state.
        self.state = self.state.handle(recvpkt, raddress, rport)
        # If we didn't throw any exceptions here, reset the retry_count to
        # zero.
        self.retry_count = 0
---|
197 | |
---|
class TftpContextServer(TftpContext):
    """The context for the server: one instance per client transfer."""
    def __init__(self,
                 host,
                 port,
                 timeout,
                 root,
                 dyn_file_func=None,
                 upload_open=None,
                 localip=""):
        """host, port     -- the client's address and port.
        timeout        -- socket timeout in seconds.
        root           -- stored as self.root for the server states
                          (presumably the served directory; verify in
                          TftpStates).
        dyn_file_func  -- stored as self.dyn_file_func for the server states.
        upload_open    -- stored as self.upload_open for the server states.
        localip        -- optional local address for this session's socket.
                          Binding to the address the client contacted keeps
                          replies coming from that same address."""
        TftpContext.__init__(self,
                             host,
                             port,
                             timeout,
                             localip)
        # At this point we have no idea if this is a download or an upload. We
        # need to let the start state determine that.
        self.state = TftpStateServerStart(self)

        self.root = root
        self.dyn_file_func = dyn_file_func
        self.upload_open = upload_open

    def __str__(self):
        return "%s:%s %s" % (self.host, self.port, self.state)

    def start(self, buffer):
        """Start the state cycle. Note that the server context receives an
        initial packet in its start method. Also note that the server does not
        loop on cycle(), as it expects the TftpServer object to manage
        that.

        buffer -- raw bytes of the client's initial datagram, parsed with
                  the packet factory and handed to the start state."""
        log.debug("In TftpContextServer.start")
        self.metrics.start_time = time.time()
        log.debug("Set metrics.start_time to %s", self.metrics.start_time)
        # And update our last updated time.
        self.last_update = time.time()

        pkt = self.factory.parse(buffer)
        log.debug("TftpContextServer.start() - factory returned a %s", pkt)

        # Call handle once with the initial packet. This should put us into
        # the download or the upload state.
        self.state = self.state.handle(pkt,
                                       self.host,
                                       self.port)

    def end(self):
        """Finish up the context: base cleanup, then record end_time and
        compute the metrics."""
        TftpContext.end(self)
        self.metrics.end_time = time.time()
        log.debug("Set metrics.end_time to %s", self.metrics.end_time)
        self.metrics.compute()
---|
249 | |
---|
class TftpContextClientUpload(TftpContext):
    """The upload context for the client during an upload.
    Note: If input is a hyphen, then we will use stdin."""
    def __init__(self,
                 host,
                 port,
                 filename,
                 input,
                 options,
                 packethook,
                 timeout,
                 localip = ""):
        """filename   -- remote file name sent in the WRQ.
        input      -- a file-like object with read(), '-' for stdin, or a
                      local path opened in binary mode.
        options    -- options dict sent with the WRQ.
        packethook -- optional callable invoked with every received packet.
        The remaining arguments are passed to TftpContext."""
        TftpContext.__init__(self,
                             host,
                             port,
                             timeout,
                             localip)
        self.file_to_transfer = filename
        self.options = options
        self.packethook = packethook
        # True when self.fileobj belongs to the caller (a file-like object or
        # stdin); end() then leaves it open instead of closing it.
        self.filelike_fileobj = False
        # If the input object has a read() function,
        # assume it is file-like.
        if hasattr(input, 'read'):
            self.fileobj = input
            self.filelike_fileobj = True
        elif input == '-':
            # This is an octet transfer, so read raw bytes. On Python 3
            # sys.stdin is a text stream whose binary layer is .buffer;
            # Python 2's sys.stdin has no .buffer and already yields bytes.
            self.fileobj = getattr(sys.stdin, 'buffer', sys.stdin)
            self.filelike_fileobj = True
        else:
            self.fileobj = open(input, "rb")

        log.debug("TftpContextClientUpload.__init__()")
        log.debug("file_to_transfer = %s, options = %s",
                  self.file_to_transfer, self.options)

    def __str__(self):
        return "%s:%s %s" % (self.host, self.port, self.state)

    def start(self):
        """Send the WRQ and drive the upload until the state becomes None.

        On TftpTimeout the last packet is resent. Once retry_count reaches
        TIMEOUT_RETRIES, the TftpTimeout is re-raised to the caller."""
        log.info("Sending tftp upload request to %s", self.host)
        log.info(" filename -> %s", self.file_to_transfer)
        log.info(" options -> %s", self.options)

        self.metrics.start_time = time.time()
        log.debug("Set metrics.start_time to %s", self.metrics.start_time)

        # FIXME: put this in a sendWRQ method?
        pkt = TftpPacketWRQ()
        pkt.filename = self.file_to_transfer
        pkt.mode = "octet" # FIXME - shouldn't hardcode this
        pkt.options = self.options
        # Send to the address resolved at construction time (rather than
        # re-resolving self.host); cycle() accepts replies only from it.
        self.sock.sendto(pkt.encode().buffer, (self.address, self.port))
        self.next_block = 1
        self.last_pkt = pkt
        # FIXME: should we centralize sendto operations so we can refactor all
        # saving of the packet to the last_pkt field?

        self.state = TftpStateSentWRQ(self)

        while self.state:
            try:
                log.debug("State is %s", self.state)
                self.cycle()
            except TftpTimeout as err:
                log.error(str(err))
                self.retry_count += 1
                if self.retry_count >= TIMEOUT_RETRIES:
                    log.debug("hit max retries, giving up")
                    raise
                else:
                    log.warning("resending last packet")
                    self.state.resendLast()

    def end(self):
        """Finish up the context: close the socket and close the input file
        only if this context opened it, then record end_time and compute the
        metrics. Caller-owned inputs (file-like objects, stdin) stay open."""
        # getattr: __init__ may have failed before the flag was set.
        if getattr(self, 'filelike_fileobj', False):
            # TftpContext.end closes any open fileobj; hide the caller's one
            # from it for the duration of the call.
            fileobj, self.fileobj = self.fileobj, None
            try:
                TftpContext.end(self)
            finally:
                self.fileobj = fileobj
        else:
            TftpContext.end(self)
        self.metrics.end_time = time.time()
        log.debug("Set metrics.end_time to %s", self.metrics.end_time)
        self.metrics.compute()
---|
327 | |
---|
328 | |
---|
class TftpContextClientDownload(TftpContext):
    """The download context for the client during a download.
    Note: If output is a hyphen, then the output will be sent to stdout."""
    def __init__(self,
                 host,
                 port,
                 filename,
                 output,
                 options,
                 packethook,
                 timeout,
                 localip = ""):
        """filename   -- remote file name sent in the RRQ.
        output     -- a file-like object with write(), '-' for stdout, or a
                      local path opened for binary writing.
        options    -- options dict sent with the RRQ.
        packethook -- optional callable invoked with every received packet.
        The remaining arguments are passed to TftpContext."""
        TftpContext.__init__(self,
                             host,
                             port,
                             timeout,
                             localip)
        # FIXME: should we refactor setting of these params?
        self.file_to_transfer = filename
        self.options = options
        self.packethook = packethook
        # True when self.fileobj belongs to the caller (a file-like object or
        # stdout); end() then leaves it open so e.g. a BytesIO stays readable.
        self.filelike_fileobj = False
        # If the output object has a write() function,
        # assume it is file-like.
        if hasattr(output, 'write'):
            self.fileobj = output
            self.filelike_fileobj = True
        # If the output filename is -, then use stdout
        elif output == '-':
            # Downloaded data is raw bytes. On Python 3 sys.stdout is a text
            # stream and rejects bytes, so use its binary .buffer; Python 2's
            # sys.stdout has no .buffer and accepts bytes directly.
            self.fileobj = getattr(sys.stdout, 'buffer', sys.stdout)
            self.filelike_fileobj = True
        else:
            self.fileobj = open(output, "wb")

        log.debug("TftpContextClientDownload.__init__()")
        log.debug("file_to_transfer = %s, options = %s",
                  self.file_to_transfer, self.options)

    def __str__(self):
        return "%s:%s %s" % (self.host, self.port, self.state)

    def start(self):
        """Initiate the download.

        Sends the RRQ and drives the transfer until the state becomes None.
        On TftpTimeout the last packet is resent. Once retry_count reaches
        TIMEOUT_RETRIES, the TftpTimeout is re-raised to the caller."""
        log.info("Sending tftp download request to %s", self.host)
        log.info(" filename -> %s", self.file_to_transfer)
        log.info(" options -> %s", self.options)

        self.metrics.start_time = time.time()
        log.debug("Set metrics.start_time to %s", self.metrics.start_time)

        # FIXME: put this in a sendRRQ method?
        pkt = TftpPacketRRQ()
        pkt.filename = self.file_to_transfer
        pkt.mode = "octet" # FIXME - shouldn't hardcode this
        pkt.options = self.options
        # Send to the address resolved at construction time (rather than
        # re-resolving self.host); cycle() accepts replies only from it.
        self.sock.sendto(pkt.encode().buffer, (self.address, self.port))
        self.next_block = 1
        self.last_pkt = pkt

        self.state = TftpStateSentRRQ(self)

        while self.state:
            try:
                log.debug("State is %s", self.state)
                self.cycle()
            except TftpTimeout as err:
                log.error(str(err))
                self.retry_count += 1
                if self.retry_count >= TIMEOUT_RETRIES:
                    log.debug("hit max retries, giving up")
                    raise
                else:
                    log.warning("resending last packet")
                    self.state.resendLast()

    def end(self):
        """Finish up the context: close the socket and close the output file
        only if this context opened it, then record end_time and compute the
        metrics. Caller-owned outputs (file-like objects, stdout) stay open."""
        # getattr: __init__ may have failed before the flag was set.
        if getattr(self, 'filelike_fileobj', False):
            # TftpContext.end closes any open fileobj; hide the caller's one
            # from it for the duration of the call.
            fileobj, self.fileobj = self.fileobj, None
            try:
                TftpContext.end(self)
            finally:
                self.fileobj = fileobj
        else:
            TftpContext.end(self)
        self.metrics.end_time = time.time()
        log.debug("Set metrics.end_time to %s", self.metrics.end_time)
        self.metrics.compute()
---|