From e5a659f33961032755f11a7fb0ca5d8ecbbd7ccd Mon Sep 17 00:00:00 2001 From: Jan Wagner Date: Mon, 19 Mar 2018 06:40:11 +0100 Subject: [PATCH] check_openvpn: Update to ba01f5f --- check_openvpn/check_openvpn | 240 +++++++++++++++++++++++++----------- 1 file changed, 167 insertions(+), 73 deletions(-) diff --git a/check_openvpn/check_openvpn b/check_openvpn/check_openvpn index b1e5c1b..dd931d6 100644 --- a/check_openvpn/check_openvpn +++ b/check_openvpn/check_openvpn @@ -2,10 +2,12 @@ # Check if an OpenVPN server runs on a given UDP or TCP port. # -# Copyright 2013 Roland Wolters -# Copyright 2016 Alarig Le Lay +# (C) 2013, Roland Wolters # -# Version 20160803 +# Contributors: +# - Andreas Stefl +# - Alarig Le Lay +# - Roland Wolters # # Permission is hereby granted, free of charge, to any person obtaining # a copy of this software and associated documentation files (the @@ -36,92 +38,168 @@ import socket import argparse import binascii -HMAC_CLIENT_KEY_START = 192 +P_CONTROL_HARD_RESET_CLIENT_V2 = 7 +P_CONTROL_HARD_RESET_SERVER_V2 = 8 +MAX_P_CONTROL_SIZE = 128 # TODO: adopt +MAX_CIPHER_KEY_LENGTH = 64 +MAX_HMAC_KEY_LENGTH = 64 BUFFER_SIZE = 1024 -ALGORITHMS_AVAILABLE = hashlib.algorithms_available \ - if hasattr(hashlib, "algorithms_available") else hashlib.algorithms +ALGORITHMS_AVAILABLE = hashlib.algorithms_available if hasattr(hashlib, "algorithms_available") else hashlib.algorithms def ok(msg): print('OK: %s' % msg) return 0 +def warning(msg): + print('WARN: %s' % msg) + return 1 + def critical(msg): print('CRIT: %s' % msg) return 2 -def buildpacket(tcp, key, digestmod): - packet = 1 - ts = int(time.time()) - session = os.urandom(8) +def build_p_control_hard_reset_client_v2(sid, digest, key): + # see openvpn source code "src/openvpn/crypto.c" function openvpn_decrypt_v1 + pid = 1 # packet id + ts = int(time.time()) # net time if key: - # hmac - h = hmac.new(key, digestmod=digestmod) - h.update(struct.pack('>I', packet)) # packet id + # generate hmac + h = hmac.new(key, digestmod=digest) + h.update(struct.pack('>I', pid)) # packet id h.update(struct.pack('>I', ts)) # net time - h.update(b'\x38') # type - h.update(session) # session id + h.update(struct.pack('>B', P_CONTROL_HARD_RESET_CLIENT_V2 << 3)) # packet type + h.update(sid) # session id h.update(struct.pack('>B', 0)) # message packet id array length h.update(struct.pack('>I', 0)) # message packet id - # packet + # build packet result = b'' - result += b'\x38' # type - result += session # session id + result += struct.pack('>B', P_CONTROL_HARD_RESET_CLIENT_V2 << 3) # packet type + result += sid # session id if key: result += h.digest() # hmac - result += struct.pack('>I', packet) # packet id - result += struct.pack('>I', ts) # net time + if key: result += struct.pack('>I', pid) # packet id + if key: result += struct.pack('>I', ts) # net time result += struct.pack('>B', 0) # message packet id array length result += struct.pack('>I', 0) # message packet id - if tcp: result = struct.pack('>H', len(result)) + result return result -def checkserver(host, port, tcp, timeout, key, digest): - packet = buildpacket(tcp, key, digest) - check = checkserver_tcp if tcp else checkserver_udp - return check(host, port, timeout, packet) +def validate_p_control_hard_reset_server_v2(packet, query_sid, digest, digest_size, key): + # see openvpn source code "src/openvpn/crypto.c" function openvpn_decrypt_v1 + # identify packet + if False: pass + elif len(packet) - struct.unpack('>B', packet[9:10])[0] * 4 == 14: plen = 0 # type sid mpida mpid + elif len(packet) - struct.unpack('>B', packet[9:10])[0] * 4 == 22: plen = 1 # type sid mpida rsid mpid + elif len(packet) - struct.unpack('>B', packet[17+digest_size:18+digest_size])[0] * 4 == 30+digest_size: plen = 2 # type sid hmac pid ts mpida rsid mpid + else: return 20 -def checkserver_udp(host, port, timeout, packet): + # parse packet + ptype = struct.unpack('>B', packet[:1])[0] # packet type + packet = packet[1:] + sid = packet[:8] # session id + packet = packet[8:] + if plen >= 2: + phmac = packet[:digest_size] # hmac + packet = packet[digest_size:] + pid = struct.unpack('>I', packet[:4])[0] # packet id + packet = packet[4:] + ts = struct.unpack('>I', packet[:4])[0] # net time + packet = packet[4:] + mpidlen = struct.unpack('>B', packet[:1])[0] # message packet id array length + packet = packet[1:] + mpidarray = [] # packet id array + for i in range(mpidlen): + mpidarray.append(struct.unpack('>I', packet[:4])[0]) # message packet id array element + packet = packet[4:] + if plen >= 1: + rsid = packet[:8] # remote session id + packet = packet[8:] + mpid = struct.unpack('>I', packet[:4])[0] # message packet id + + # validate packet + if ptype != P_CONTROL_HARD_RESET_SERVER_V2 << 3: return 20 + if mpid != 0: return 20 + if plen >= 1 and rsid != query_sid: return 20 + if plen >= 2 and key: + if pid != 1: return 20 + # generate hmac + h = hmac.new(key, digestmod=digest) + h.update(struct.pack('>I', pid)) # packet id + h.update(struct.pack('>I', ts)) # net time + h.update(struct.pack('>B', ptype)) # packet type + h.update(sid) # session id + h.update(struct.pack('>B', mpidlen)) # message packet id array length + for e in mpidarray: h.update(struct.pack('>I', e)) # message packet id array element + h.update(rsid) # remote session id + h.update(struct.pack('>I', mpid)) # message packet id + if phmac != h.digest(): return 10 + return 1 + return 0 + +def check(host, port, tcp, timeout, digest, digest_size, client_key, server_key, retrycount, validate): + sid = os.urandom(8) # session id + packet = build_p_control_hard_reset_client_v2(sid, digest, client_key) + query_server = query_tcp_server if tcp else query_udp_server + + try: + s = create_socket(host, port, tcp, timeout) + except socket.error: + return critical('Unable to create socket') + + try: + response = query_server(s, host, port, packet, retrycount) + except RuntimeError: + return critical('Invalid response') + except: + return critical('Not responding') + finally: + s.close() + + if response is None: + return critical('Not responding') + + # for debugging purpose + # response = binascii.hexlify(response) + if not validate: return ok('Responded') + valid = validate_p_control_hard_reset_server_v2(response, sid, digest, digest_size, server_key) + if valid == 0: return ok('Response validated') + if valid == 1: return ok('Response validated, checked HMAC') + if valid == 10: return warning('Invalid HMAC') + return critical('Invalid response') + +def create_socket(host, port, tcp, timeout): # thanks to glucas for the idea - try: - af, socktype, proto, canonname, sa = socket.getaddrinfo(host, port, \ - socket.AF_UNSPEC, socket.SOCK_DGRAM)[0] - s = socket.socket(af, socktype, proto) - s.settimeout(timeout) - except socket.error: - return critical('Unable to create UDP socket') + sock_type = socket.SOCK_STREAM if tcp else socket.SOCK_DGRAM + af, socktype, proto, canonname, sa = socket.getaddrinfo(host, port, socket.AF_UNSPEC, sock_type)[0] + s = socket.socket(af, socktype, proto) + s.settimeout(timeout) + return s - try: - s.sendto(packet, (host, port)) - data, _ = s.recvfrom(BUFFER_SIZE) - reply = binascii.hexlify(data) - return ok('OpenVPN UDP server response (hex): %s' % reply) - except: - return critical('OpenVPN UDP server not responding') - finally: - s.close() +def query_udp_server(s, host, port, packet, retrycount): + # Send up to 'retrycount' UDP packets with 'timeout' secs between each. + # Return data after receiving first UDP packet. + for i in range(retrycount): + try: + s.sendto(packet, (host, port)) + # TODO: check response address + data, _ = s.recvfrom(BUFFER_SIZE) + return data + except socket.timeout: + pass + except: + return None + return None -def checkserver_tcp(host, port, timeout, packet): - try: - af, socktype, proto, canonname, sa = socket.getaddrinfo(host, port, \ - socket.AF_UNSPEC, socket.SOCK_STREAM)[0] - s = socket.socket(af, socktype, proto) - s.settimeout(timeout) - except socket.error: - return critical('Unable to create TCP socket') - - try: - s.connect((host, port)) - s.send(packet) - data = s.recv(BUFFER_SIZE) - if len(data) <= 0: raise RuntimeError - reply = binascii.hexlify(data) - return ok('OpenVPN TCP server response (hex): %s' % reply) - except: - return critical('OpenVPN TCP server not responding') - finally: - s.close() +def query_tcp_server(s, host, port, packet, retrycount): + # ignore retrycount + s.connect((host, port)) + s.send(struct.pack('>H', len(packet)) + packet) + length = struct.unpack('>H', s.recv(2))[0] + if length > MAX_P_CONTROL_SIZE: raise RuntimeError + data = s.recv(length) + if len(data) != length: raise RuntimeError + return data def readkey(path): key = None @@ -135,17 +213,23 @@ def readkey(path): return None index_start += 2 key = key[index_start:index_end].replace('\n', '').replace('\r', '') + key = binascii.unhexlify(key) return key def optionsparser(argv=None): parser = argparse.ArgumentParser() parser.add_argument('-p', '--port', help='set port number (default is %(default)d)', type=int, default=1194) parser.add_argument('-t', '--tcp', help='use tcp instead of udp', action='store_true') - parser.add_argument('--timeout', help='set timeout (default is %(default)d)', type=int, default=5) - parser.add_argument('--digest', help='set HMAC digest (default is "%(default)s")', default='sha1') + parser.add_argument('--timeout', help='set timeout in seconds, for udp counted per packet (default is %(default)d)', type=int, default=2) + parser.add_argument('--digest', help='set digest algorithm (default is "%(default)s")', default='sha1') parser.add_argument('--digest-size', help='set HMAC digest size', type=int) - parser.add_argument('--digest-key', help='set HMAC key') + parser.add_argument('--digest-key-client', help='set client HMAC key') + parser.add_argument('--digest-key-server', help='set server HMAC key for packet validation') parser.add_argument('--tls-auth', help='set tls-auth file') + # TODO: direction argument (normal, inverse) + parser.add_argument('--tls-auth-inverse', help='set tls-auth file direction to inverse (1)', action='store_true') + parser.add_argument('--retrycount', help='number of udp retries before giving up (default is %(default)d)', type=int, default=3) + parser.add_argument('--no-validation', help='do not validate response', action='store_true') parser.add_argument('host', help='the OpenVPN host name or IP') return parser.parse_args(argv) @@ -154,10 +238,13 @@ def main(argv=None): if args.digest_size and args.digest_size < 0: critical('digest size must be positive') - if args.tls_auth and args.digest_key: + if args.retrycount < 1: + critical('retry count must be positive') + if args.tls_auth and (args.digest_key_client or args.digest_key_server): critical('--tls-auth cannot go with --digest-key') - key = args.digest_key + client_key = binascii.unhexlify(args.digest_key_client) if args.digest_key_client else None + server_key = binascii.unhexlify(args.digest_key_server) if args.digest_key_server else None digest = args.digest digest_size = args.digest_size @@ -171,15 +258,22 @@ def main(argv=None): return critical('digest creation failed') if args.tls_auth: + # see openvpn source code "src/openvpn/crypto.h", "src/openvpn/crypto.c" and "src/openvpn/crypto_backend.h" + # 64 byte cipher direction 0, 64 byte hmac direction 0, 64 byte cipher direction 1, 64 byte hmac direction 1 (=> 2048 bit) key = readkey(args.tls_auth) if key == None: return critical('cannot read tls auth file') - index_start = HMAC_CLIENT_KEY_START * 2 - index_end = (HMAC_CLIENT_KEY_START + digest_size) * 2 - key = key[index_start:index_end] + index_start = MAX_CIPHER_KEY_LENGTH + MAX_HMAC_KEY_LENGTH + MAX_CIPHER_KEY_LENGTH + index_end = index_start + MAX_HMAC_KEY_LENGTH + client_key = key[index_start:index_end] + index_start = MAX_CIPHER_KEY_LENGTH + index_end = index_start + MAX_HMAC_KEY_LENGTH + server_key = key[index_start:index_end] + if args.tls_auth_inverse: client_key, server_key = server_key, client_key + # reduce key size to required size + client_key = client_key[:digest_size] + server_key = server_key[:digest_size] - if key: key = binascii.unhexlify(key) - - return checkserver(args.host, args.port, args.tcp, args.timeout, key, digest) + return check(args.host, args.port, args.tcp, args.timeout, digest, digest_size, client_key, server_key, args.retrycount, not args.no_validation) if __name__ == '__main__': code = main()