sbase

suckless unix tools
git clone git://git.suckless.org/sbase
Log | Files | Refs | README | LICENSE

tftp.c (5791B)


      1 /* See LICENSE file for copyright and license details. */
      2 #include <sys/time.h>
      3 #include <sys/types.h>
      4 #include <sys/socket.h>
      5 
      6 #include <netdb.h>
      7 #include <netinet/in.h>
      8 
      9 #include <errno.h>
     10 #include <stdio.h>
     11 #include <stdlib.h>
     12 #include <string.h>
     13 #include <unistd.h>
     14 
     15 #include "util.h"
     16 
/* maximum number of payload bytes carried by one DATA packet */
#define BLKSIZE 512
/* packet header: 2-byte opcode followed by a 2-byte block number */
#define HDRSIZE 4
/* size of the buffers used for sending and receiving packets */
#define PKTSIZE (BLKSIZE + HDRSIZE)

/* receive timeout, in seconds, installed on the socket in main() */
#define TIMEOUT_SEC 5
/* transfer will time out after NRETRIES * TIMEOUT_SEC */
#define NRETRIES 5

/*
 * Opcodes stored big-endian in the first two bytes of every packet.
 * They double as the values of the `state` variable driving the
 * getfile()/putfile() state machines.
 */
#define RRQ  1 /* read request, sent by getfile() */
#define WWQ  2 /* write request, sent by putfile() */
#define DATA 3 /* data block */
#define ACK  4 /* acknowledgement of a block */
#define ERR  5 /* error report from the peer */
     30 
/*
 * Human-readable messages indexed by the error code found in an ERR
 * packet; unpackerrc() rejects codes outside this table.
 */
static char *errtext[] = {
	"Undefined",
	"File not found",
	"Access violation",
	"Disk full or allocation exceeded",
	"Illegal TFTP operation",
	"Unknown transfer ID",
	"File already exists",
	"No such user"
};

/* peer address: set in main(), then overwritten by every recvfrom() */
static struct sockaddr_storage to;
/* length of the address stored in `to` */
static socklen_t tolen;
/* number of consecutive failed receives, reset on success (see readpkt()) */
static int timeout;
/* current state of the transfer state machine (one of the opcodes) */
static int state;
/* UDP socket used for the whole transfer */
static int s;
     47 
     48 static int
     49 packreq(unsigned char *buf, int op, char *path, char *mode)
     50 {
     51 	unsigned char *p = buf;
     52 
     53 	*p++ = op >> 8;
     54 	*p++ = op & 0xff;
     55 	if (strlen(path) + 1 > 256)
     56 		eprintf("filename too long\n");
     57 	memcpy(p, path, strlen(path) + 1);
     58 	p += strlen(path) + 1;
     59 	memcpy(p, mode, strlen(mode) + 1);
     60 	p += strlen(mode) + 1;
     61 	return p - buf;
     62 }
     63 
     64 static int
     65 packack(unsigned char *buf, int blkno)
     66 {
     67 	buf[0] = ACK >> 8;
     68 	buf[1] = ACK & 0xff;
     69 	buf[2] = blkno >> 8;
     70 	buf[3] = blkno & 0xff;
     71 	return 4;
     72 }
     73 
     74 static int
     75 packdata(unsigned char *buf, int blkno)
     76 {
     77 	buf[0] = DATA >> 8;
     78 	buf[1] = DATA & 0xff;
     79 	buf[2] = blkno >> 8;
     80 	buf[3] = blkno & 0xff;
     81 	return 4;
     82 }
     83 
     84 static int
     85 unpackop(unsigned char *buf)
     86 {
     87 	return (buf[0] << 8) | (buf[1] & 0xff);
     88 }
     89 
     90 static int
     91 unpackblkno(unsigned char *buf)
     92 {
     93 	return (buf[2] << 8) | (buf[3] & 0xff);
     94 }
     95 
     96 static int
     97 unpackerrc(unsigned char *buf)
     98 {
     99 	int errc;
    100 
    101 	errc = (buf[2] << 8) | (buf[3] & 0xff);
    102 	if (errc < 0 || errc >= LEN(errtext))
    103 		eprintf("bad error code: %d\n", errc);
    104 	return errc;
    105 }
    106 
    107 static int
    108 writepkt(unsigned char *buf, int len)
    109 {
    110 	int n;
    111 
    112 	n = sendto(s, buf, len, 0, (struct sockaddr *)&to,
    113 	           tolen);
    114 	if (n < 0)
    115 		if (errno != EINTR)
    116 			eprintf("sendto:");
    117 	return n;
    118 }
    119 
    120 static int
    121 readpkt(unsigned char *buf, int len)
    122 {
    123 	int n;
    124 
    125 	n = recvfrom(s, buf, len, 0, (struct sockaddr *)&to,
    126 	             &tolen);
    127 	if (n < 0) {
    128 		if (errno != EINTR && errno != EWOULDBLOCK)
    129 			eprintf("recvfrom:");
    130 		timeout++;
    131 		if (timeout == NRETRIES)
    132 			eprintf("transfer timed out\n");
    133 	} else {
    134 		timeout = 0;
    135 	}
    136 	return n;
    137 }
    138 
    139 static void
    140 getfile(char *file)
    141 {
    142 	unsigned char buf[PKTSIZE];
    143 	int n, op, blkno, nextblkno = 1, done = 0;
    144 
    145 	state = RRQ;
    146 	for (;;) {
    147 		switch (state) {
    148 		case RRQ:
    149 			n = packreq(buf, RRQ, file, "octet");
    150 			writepkt(buf, n);
    151 			n = readpkt(buf, sizeof(buf));
    152 			if (n > 0) {
    153 				op = unpackop(buf);
    154 				if (op != DATA && op != ERR)
    155 					eprintf("bad opcode: %d\n", op);
    156 				state = op;
    157 			}
    158 			break;
    159 		case DATA:
    160 			n -= HDRSIZE;
    161 			if (n < 0)
    162 				eprintf("truncated packet\n");
    163 			blkno = unpackblkno(buf);
    164 			if (blkno == nextblkno) {
    165 				nextblkno++;
    166 				write(1, &buf[HDRSIZE], n);
    167 			}
    168 			if (n < BLKSIZE)
    169 				done = 1;
    170 			state = ACK;
    171 			break;
    172 		case ACK:
    173 			n = packack(buf, blkno);
    174 			writepkt(buf, n);
    175 			if (done)
    176 				return;
    177 			n = readpkt(buf, sizeof(buf));
    178 			if (n > 0) {
    179 				op = unpackop(buf);
    180 				if (op != DATA && op != ERR)
    181 					eprintf("bad opcode: %d\n", op);
    182 				state = op;
    183 			}
    184 			break;
    185 		case ERR:
    186 			eprintf("error: %s\n", errtext[unpackerrc(buf)]);
    187 		}
    188 	}
    189 }
    190 
    191 static void
    192 putfile(char *file)
    193 {
    194 	unsigned char inbuf[PKTSIZE], outbuf[PKTSIZE];
    195 	int inb, outb, op, blkno, nextblkno = 0, done = 0;
    196 
    197 	state = WWQ;
    198 	for (;;) {
    199 		switch (state) {
    200 		case WWQ:
    201 			outb = packreq(outbuf, WWQ, file, "octet");
    202 			writepkt(outbuf, outb);
    203 			inb = readpkt(inbuf, sizeof(inbuf));
    204 			if (inb > 0) {
    205 				op = unpackop(inbuf);
    206 				if (op != ACK && op != ERR)
    207 					eprintf("bad opcode: %d\n", op);
    208 				state = op;
    209 			}
    210 			break;
    211 		case DATA:
    212 			if (blkno == nextblkno) {
    213 				nextblkno++;
    214 				packdata(outbuf, nextblkno);
    215 				outb = read(0, &outbuf[HDRSIZE], BLKSIZE);
    216 				if (outb < BLKSIZE)
    217 					done = 1;
    218 			}
    219 			writepkt(outbuf, outb + HDRSIZE);
    220 			inb = readpkt(inbuf, sizeof(inbuf));
    221 			if (inb > 0) {
    222 				op = unpackop(inbuf);
    223 				if (op != ACK && op != ERR)
    224 					eprintf("bad opcode: %d\n", op);
    225 				state = op;
    226 			}
    227 			break;
    228 		case ACK:
    229 			if (inb < HDRSIZE)
    230 				eprintf("truncated packet\n");
    231 			blkno = unpackblkno(inbuf);
    232 			if (blkno == nextblkno)
    233 				if (done)
    234 					return;
    235 			state = DATA;
    236 			break;
    237 		case ERR:
    238 			eprintf("error: %s\n", errtext[unpackerrc(inbuf)]);
    239 		}
    240 	}
    241 }
    242 
/*
 * Print the command-line synopsis through eprintf() (from util.h,
 * assumed not to return).
 */
static void
usage(void)
{
	eprintf("usage: %s -h host [-p port] [-x | -c] file\n", argv0);
}
    248 
    249 int
    250 main(int argc, char *argv[])
    251 {
    252 	struct addrinfo hints, *res, *r;
    253 	struct timeval tv;
    254 	char *host = NULL, *port = "tftp";
    255 	void (*fn)(char *) = getfile;
    256 	int ret;
    257 
    258 	ARGBEGIN {
    259 	case 'h':
    260 		host = EARGF(usage());
    261 		break;
    262 	case 'p':
    263 		port = EARGF(usage());
    264 		break;
    265 	case 'x':
    266 		fn = getfile;
    267 		break;
    268 	case 'c':
    269 		fn = putfile;
    270 		break;
    271 	default:
    272 		usage();
    273 	} ARGEND
    274 
    275 	if (!host || !argc)
    276 		usage();
    277 
    278 	memset(&hints, 0, sizeof(hints));
    279 	hints.ai_family = AF_UNSPEC;
    280 	hints.ai_socktype = SOCK_DGRAM;
    281 	hints.ai_protocol = IPPROTO_UDP;
    282 	ret = getaddrinfo(host, port, &hints, &res);
    283 	if (ret)
    284 		eprintf("getaddrinfo: %s\n", gai_strerror(ret));
    285 
    286 	for (r = res; r; r = r->ai_next) {
    287 		if (r->ai_family != AF_INET &&
    288 		    r->ai_family != AF_INET6)
    289 			continue;
    290 		s = socket(r->ai_family, r->ai_socktype,
    291 		           r->ai_protocol);
    292 		if (s < 0)
    293 			continue;
    294 		break;
    295 	}
    296 	if (!r)
    297 		eprintf("cannot create socket\n");
    298 	memcpy(&to, r->ai_addr, r->ai_addrlen);
    299 	tolen = r->ai_addrlen;
    300 	freeaddrinfo(res);
    301 
    302 	tv.tv_sec = TIMEOUT_SEC;
    303 	tv.tv_usec = 0;
    304 	if (setsockopt(s, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)) < 0)
    305 		eprintf("setsockopt:");
    306 
    307 	fn(argv[0]);
    308 	return 0;
    309 }