check_openvpn: Update to ba01f5f
This commit is contained in:
		
							parent
							
								
									f134dbfd1c
								
							
						
					
					
						commit
						e5a659f339
					
				
					 1 changed files with 167 additions and 73 deletions
				
			
		| 
						 | 
				
			
			@ -2,10 +2,12 @@
 | 
			
		|||
 | 
			
		||||
# Check if an OpenVPN server runs on a given UDP or TCP port.
 | 
			
		||||
#
 | 
			
		||||
# Copyright 2013 Roland Wolters
 | 
			
		||||
# Copyright 2016 Alarig Le Lay
 | 
			
		||||
# (C) 2013, Roland Wolters
 | 
			
		||||
#
 | 
			
		||||
# Version 20160803
 | 
			
		||||
# Contributors:
 | 
			
		||||
# - Andreas Stefl
 | 
			
		||||
# - Alarig Le Lay
 | 
			
		||||
# - Roland Wolters
 | 
			
		||||
#
 | 
			
		||||
# Permission is hereby granted, free of charge, to any person obtaining
 | 
			
		||||
# a copy of this software and associated documentation files (the
 | 
			
		||||
| 
						 | 
				
			
			@ -36,92 +38,168 @@ import socket
 | 
			
		|||
import argparse
 | 
			
		||||
import binascii
 | 
			
		||||
 | 
			
		||||
HMAC_CLIENT_KEY_START = 192
 | 
			
		||||
P_CONTROL_HARD_RESET_CLIENT_V2 = 7
 | 
			
		||||
P_CONTROL_HARD_RESET_SERVER_V2 = 8
 | 
			
		||||
MAX_P_CONTROL_SIZE = 128 # TODO: adopt
 | 
			
		||||
MAX_CIPHER_KEY_LENGTH = 64
 | 
			
		||||
MAX_HMAC_KEY_LENGTH = 64
 | 
			
		||||
BUFFER_SIZE = 1024
 | 
			
		||||
 | 
			
		||||
ALGORITHMS_AVAILABLE = hashlib.algorithms_available \
 | 
			
		||||
    if hasattr(hashlib, "algorithms_available") else hashlib.algorithms
 | 
			
		||||
ALGORITHMS_AVAILABLE = hashlib.algorithms_available if hasattr(hashlib, "algorithms_available") else hashlib.algorithms
 | 
			
		||||
 | 
			
		||||
def ok(msg):
 | 
			
		||||
    print('OK: %s' % msg)
 | 
			
		||||
    return 0
 | 
			
		||||
 | 
			
		||||
def warning(msg):
 | 
			
		||||
    print('WARN: %s' % msg)
 | 
			
		||||
    return 1
 | 
			
		||||
 | 
			
		||||
def critical(msg):
 | 
			
		||||
    print('CRIT: %s' % msg)
 | 
			
		||||
    return 2
 | 
			
		||||
 | 
			
		||||
def buildpacket(tcp, key, digestmod):
 | 
			
		||||
    packet = 1
 | 
			
		||||
    ts = int(time.time())
 | 
			
		||||
    session = os.urandom(8)
 | 
			
		||||
def build_p_control_hard_reset_client_v2(sid, digest, key):
 | 
			
		||||
    # see openvpn source code "src/openvpn/crypto.c" function openvpn_decrypt_v1
 | 
			
		||||
    pid = 1 # packet id
 | 
			
		||||
    ts = int(time.time())  # net time
 | 
			
		||||
 | 
			
		||||
    if key:
 | 
			
		||||
        # hmac
 | 
			
		||||
        h = hmac.new(key, digestmod=digestmod)
 | 
			
		||||
        h.update(struct.pack('>I', packet)) # packet id
 | 
			
		||||
        # generate hmac
 | 
			
		||||
        h = hmac.new(key, digestmod=digest)
 | 
			
		||||
        h.update(struct.pack('>I', pid)) # packet id
 | 
			
		||||
        h.update(struct.pack('>I', ts)) # net time
 | 
			
		||||
        h.update(b'\x38') # type
 | 
			
		||||
        h.update(session) # session id
 | 
			
		||||
        h.update(struct.pack('>B', P_CONTROL_HARD_RESET_CLIENT_V2 << 3)) # packet type
 | 
			
		||||
        h.update(sid) # session id
 | 
			
		||||
        h.update(struct.pack('>B', 0)) # message packet id array length
 | 
			
		||||
        h.update(struct.pack('>I', 0)) # message packet id
 | 
			
		||||
 | 
			
		||||
    # packet
 | 
			
		||||
    # build packet
 | 
			
		||||
    result = b''
 | 
			
		||||
    result += b'\x38' # type
 | 
			
		||||
    result += session # session id
 | 
			
		||||
    result += struct.pack('>B', P_CONTROL_HARD_RESET_CLIENT_V2 << 3) # packet type
 | 
			
		||||
    result += sid # session id
 | 
			
		||||
    if key: result += h.digest() # hmac
 | 
			
		||||
    result += struct.pack('>I', packet) # packet id
 | 
			
		||||
    result += struct.pack('>I', ts) # net time
 | 
			
		||||
    if key: result += struct.pack('>I', pid) # packet id
 | 
			
		||||
    if key: result += struct.pack('>I', ts) # net time
 | 
			
		||||
    result += struct.pack('>B', 0) # message packet id array length
 | 
			
		||||
    result += struct.pack('>I', 0) # message packet id
 | 
			
		||||
    if tcp: result = struct.pack('>H', len(result)) + result
 | 
			
		||||
    return result
 | 
			
		||||
 | 
			
		||||
def checkserver(host, port, tcp, timeout, key, digest):
 | 
			
		||||
    packet = buildpacket(tcp, key, digest)
 | 
			
		||||
    check = checkserver_tcp if tcp else checkserver_udp
 | 
			
		||||
    return check(host, port, timeout, packet)
 | 
			
		||||
def validate_p_control_hard_reset_server_v2(packet, query_sid, digest, digest_size, key):
 | 
			
		||||
    # see openvpn source code "src/openvpn/crypto.c" function openvpn_decrypt_v1
 | 
			
		||||
    # identify packet
 | 
			
		||||
    if False: pass
 | 
			
		||||
    elif len(packet) - struct.unpack('>B', packet[9:10])[0] * 4 == 14: plen = 0 # type sid mpida mpid
 | 
			
		||||
    elif len(packet) - struct.unpack('>B', packet[9:10])[0] * 4 == 22: plen = 1 # type sid mpida rsid mpid
 | 
			
		||||
    elif len(packet) - struct.unpack('>B', packet[17+digest_size:18+digest_size])[0] * 4 == 30+digest_size: plen = 2 # type sid hmac pid ts mpida rsid mpid
 | 
			
		||||
    else: return 20
 | 
			
		||||
 | 
			
		||||
def checkserver_udp(host, port, timeout, packet):
 | 
			
		||||
    # parse packet
 | 
			
		||||
    ptype = struct.unpack('>B', packet[:1])[0] # packet type
 | 
			
		||||
    packet = packet[1:]
 | 
			
		||||
    sid = packet[:8] # session id
 | 
			
		||||
    packet = packet[8:]
 | 
			
		||||
    if plen >= 2:
 | 
			
		||||
        phmac = packet[:digest_size] # hmac
 | 
			
		||||
        packet = packet[digest_size:]
 | 
			
		||||
        pid = struct.unpack('>I', packet[:4])[0] # packet id
 | 
			
		||||
        packet = packet[4:]
 | 
			
		||||
        ts = struct.unpack('>I', packet[:4])[0] # net time
 | 
			
		||||
        packet = packet[4:]
 | 
			
		||||
    mpidlen = struct.unpack('>B', packet[:1])[0] # message packet id array length
 | 
			
		||||
    packet = packet[1:]
 | 
			
		||||
    mpidarray = [] # packet id array
 | 
			
		||||
    for i in range(mpidlen):
 | 
			
		||||
        mpidarray.append(struct.unpack('>I', packet[:4])[0]) # message packet id array element
 | 
			
		||||
        packet = packet[4:]
 | 
			
		||||
    if plen >= 1:
 | 
			
		||||
        rsid = packet[:8] # remote session id
 | 
			
		||||
        packet = packet[8:]
 | 
			
		||||
    mpid = struct.unpack('>I', packet[:4])[0] # message packet id
 | 
			
		||||
 | 
			
		||||
    # validate packet
 | 
			
		||||
    if ptype != P_CONTROL_HARD_RESET_SERVER_V2 << 3: return 20
 | 
			
		||||
    if mpid != 0: return 20
 | 
			
		||||
    if plen >= 1 and rsid != query_sid: return 20
 | 
			
		||||
    if plen >= 2 and key:
 | 
			
		||||
        if pid != 1: return 20
 | 
			
		||||
        # generate hmac
 | 
			
		||||
        h = hmac.new(key, digestmod=digest)
 | 
			
		||||
        h.update(struct.pack('>I', pid)) # packet id
 | 
			
		||||
        h.update(struct.pack('>I', ts)) # net time
 | 
			
		||||
        h.update(struct.pack('>B', ptype)) # packet type
 | 
			
		||||
        h.update(sid) # session id
 | 
			
		||||
        h.update(struct.pack('>B', mpidlen)) # message packet id array length
 | 
			
		||||
        for e in mpidarray: h.update(struct.pack('>I', e)) # message packet id array element
 | 
			
		||||
        h.update(rsid) # remote session id
 | 
			
		||||
        h.update(struct.pack('>I', mpid)) # message packet id
 | 
			
		||||
        if phmac != h.digest(): return 10
 | 
			
		||||
        return 1
 | 
			
		||||
    return 0
 | 
			
		||||
 | 
			
		||||
def check(host, port, tcp, timeout, digest, digest_size, client_key, server_key, retrycount, validate):
 | 
			
		||||
    sid = os.urandom(8) # session id
 | 
			
		||||
    packet = build_p_control_hard_reset_client_v2(sid, digest, client_key)
 | 
			
		||||
    query_server = query_tcp_server if tcp else query_udp_server
 | 
			
		||||
 | 
			
		||||
    try:
 | 
			
		||||
        s = create_socket(host, port, tcp, timeout)
 | 
			
		||||
    except socket.error:
 | 
			
		||||
        return critical('Unable to create socket')
 | 
			
		||||
 | 
			
		||||
    try:
 | 
			
		||||
        response = query_server(s, host, port, packet, retrycount)
 | 
			
		||||
    except RuntimeError:
 | 
			
		||||
        return critical('Invalid response')
 | 
			
		||||
    except:
 | 
			
		||||
        return critical('Not responding')
 | 
			
		||||
    finally:
 | 
			
		||||
        s.close()
 | 
			
		||||
 | 
			
		||||
    if response is None:
 | 
			
		||||
        return critical('Not responding')
 | 
			
		||||
 | 
			
		||||
    # for debugging purpose
 | 
			
		||||
    # response = binascii.hexlify(response)
 | 
			
		||||
    if not validate: return ok('Responded')
 | 
			
		||||
    valid = validate_p_control_hard_reset_server_v2(response, sid, digest, digest_size, server_key)
 | 
			
		||||
    if valid == 0: return ok('Response validated')
 | 
			
		||||
    if valid == 1: return ok('Response validated, checked HMAC')
 | 
			
		||||
    if valid == 10: return warning('Invalid HMAC')
 | 
			
		||||
    return critical('Invalid response')
 | 
			
		||||
 | 
			
		||||
def create_socket(host, port, tcp, timeout):
 | 
			
		||||
    # thanks to glucas for the idea
 | 
			
		||||
    try:
 | 
			
		||||
        af, socktype, proto, canonname, sa = socket.getaddrinfo(host, port, \
 | 
			
		||||
            socket.AF_UNSPEC, socket.SOCK_DGRAM)[0]
 | 
			
		||||
        s = socket.socket(af, socktype, proto)
 | 
			
		||||
        s.settimeout(timeout)
 | 
			
		||||
    except socket.error:
 | 
			
		||||
        return critical('Unable to create UDP socket')
 | 
			
		||||
    sock_type = socket.SOCK_STREAM if tcp else socket.SOCK_DGRAM
 | 
			
		||||
    af, socktype, proto, canonname, sa = socket.getaddrinfo(host, port, socket.AF_UNSPEC, sock_type)[0]
 | 
			
		||||
    s = socket.socket(af, socktype, proto)
 | 
			
		||||
    s.settimeout(timeout)
 | 
			
		||||
    return s
 | 
			
		||||
 | 
			
		||||
    try:
 | 
			
		||||
        s.sendto(packet, (host, port))
 | 
			
		||||
        data, _ = s.recvfrom(BUFFER_SIZE)
 | 
			
		||||
        reply = binascii.hexlify(data)
 | 
			
		||||
        return ok('OpenVPN UDP server response (hex): %s' % reply)
 | 
			
		||||
    except:
 | 
			
		||||
        return critical('OpenVPN UDP server not responding')
 | 
			
		||||
    finally:
 | 
			
		||||
        s.close()
 | 
			
		||||
def query_udp_server(s, host, port, packet, retrycount):
 | 
			
		||||
    # Send up to 'retrycount' UDP packets with 'timeout' secs between each.
 | 
			
		||||
    # Return data after receiving first UDP packet.
 | 
			
		||||
    for i in range(retrycount):
 | 
			
		||||
        try:
 | 
			
		||||
            s.sendto(packet, (host, port))
 | 
			
		||||
            # TODO: check response address
 | 
			
		||||
            data, _ = s.recvfrom(BUFFER_SIZE)
 | 
			
		||||
            return data
 | 
			
		||||
        except socket.timeout:
 | 
			
		||||
            pass
 | 
			
		||||
        except:
 | 
			
		||||
            return None
 | 
			
		||||
    return None
 | 
			
		||||
 | 
			
		||||
def checkserver_tcp(host, port, timeout, packet):
 | 
			
		||||
    try:
 | 
			
		||||
        af, socktype, proto, canonname, sa = socket.getaddrinfo(host, port, \
 | 
			
		||||
            socket.AF_UNSPEC, socket.SOCK_STREAM)[0]
 | 
			
		||||
        s = socket.socket(af, socktype, proto)
 | 
			
		||||
        s.settimeout(timeout)
 | 
			
		||||
    except socket.error:
 | 
			
		||||
        return critical('Unable to create TCP socket')
 | 
			
		||||
 | 
			
		||||
    try:
 | 
			
		||||
        s.connect((host, port))
 | 
			
		||||
        s.send(packet)
 | 
			
		||||
        data = s.recv(BUFFER_SIZE)
 | 
			
		||||
        if len(data) <= 0: raise RuntimeError
 | 
			
		||||
        reply = binascii.hexlify(data)
 | 
			
		||||
        return ok('OpenVPN TCP server response (hex): %s' % reply)
 | 
			
		||||
    except:
 | 
			
		||||
        return critical('OpenVPN TCP server not responding')
 | 
			
		||||
    finally:
 | 
			
		||||
        s.close()
 | 
			
		||||
def query_tcp_server(s, host, port, packet, retrycount):
 | 
			
		||||
    # ignore retrycount
 | 
			
		||||
    s.connect((host, port))
 | 
			
		||||
    s.send(struct.pack('>H', len(packet)) + packet)
 | 
			
		||||
    length = struct.unpack('>H', s.recv(2))[0]
 | 
			
		||||
    if length > MAX_P_CONTROL_SIZE: raise RuntimeError
 | 
			
		||||
    data = s.recv(length)
 | 
			
		||||
    if len(data) != length: raise RuntimeError
 | 
			
		||||
    return data
 | 
			
		||||
 | 
			
		||||
def readkey(path):
 | 
			
		||||
    key = None
 | 
			
		||||
| 
						 | 
				
			
			@ -135,17 +213,23 @@ def readkey(path):
 | 
			
		|||
        return None
 | 
			
		||||
    index_start += 2
 | 
			
		||||
    key = key[index_start:index_end].replace('\n', '').replace('\r', '')
 | 
			
		||||
    key = binascii.unhexlify(key)
 | 
			
		||||
    return key
 | 
			
		||||
 | 
			
		||||
def optionsparser(argv=None):
 | 
			
		||||
    parser = argparse.ArgumentParser()
 | 
			
		||||
    parser.add_argument('-p', '--port', help='set port number (default is %(default)d)', type=int, default=1194)
 | 
			
		||||
    parser.add_argument('-t', '--tcp', help='use tcp instead of udp', action='store_true')
 | 
			
		||||
    parser.add_argument('--timeout', help='set timeout (default is %(default)d)', type=int, default=5)
 | 
			
		||||
    parser.add_argument('--digest', help='set HMAC digest (default is "%(default)s")', default='sha1')
 | 
			
		||||
    parser.add_argument('--timeout', help='set timeout in seconds, for udp counted per packet (default is %(default)d)', type=int, default=2)
 | 
			
		||||
    parser.add_argument('--digest', help='set digest algorithm (default is "%(default)s")', default='sha1')
 | 
			
		||||
    parser.add_argument('--digest-size', help='set HMAC digest size', type=int)
 | 
			
		||||
    parser.add_argument('--digest-key', help='set HMAC key')
 | 
			
		||||
    parser.add_argument('--digest-key-client', help='set client HMAC key')
 | 
			
		||||
    parser.add_argument('--digest-key-server', help='set server HMAC key for packet validation')
 | 
			
		||||
    parser.add_argument('--tls-auth', help='set tls-auth file')
 | 
			
		||||
    # TODO: direction argument (normal, inverse)
 | 
			
		||||
    parser.add_argument('--tls-auth-inverse', help='set tls-auth file direction to inverse (1)', action='store_true')
 | 
			
		||||
    parser.add_argument('--retrycount', help='number of udp retries before giving up (default is %(default)d)', type=int, default=3)
 | 
			
		||||
    parser.add_argument('--no-validation', help='do not validate response', action='store_true')
 | 
			
		||||
    parser.add_argument('host', help='the OpenVPN host name or IP')
 | 
			
		||||
    return parser.parse_args(argv)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -154,10 +238,13 @@ def main(argv=None):
 | 
			
		|||
 | 
			
		||||
    if args.digest_size and args.digest_size < 0:
 | 
			
		||||
        critical('digest size must be positive')
 | 
			
		||||
    if args.tls_auth and args.digest_key:
 | 
			
		||||
    if args.retrycount < 1:
 | 
			
		||||
        critical('retry count must be positive')
 | 
			
		||||
    if args.tls_auth and (args.digest_key_client or args.digest_key_server):
 | 
			
		||||
        critical('--tls-auth cannot go with --digest-key')
 | 
			
		||||
 | 
			
		||||
    key = args.digest_key
 | 
			
		||||
    client_key = binascii.unhexlify(args.digest_key_client) if args.digest_key_client else None
 | 
			
		||||
    server_key = binascii.unhexlify(args.digest_key_server) if args.digest_key_server else None
 | 
			
		||||
    digest = args.digest
 | 
			
		||||
    digest_size = args.digest_size
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -171,15 +258,22 @@ def main(argv=None):
 | 
			
		|||
        return critical('digest creation failed')
 | 
			
		||||
 | 
			
		||||
    if args.tls_auth:
 | 
			
		||||
        # see openvpn source code "src/openvpn/crypto.h", "src/openvpn/crypto.c" and "src/openvpn/crypto_backend.h"
 | 
			
		||||
        # 64 byte cipher direction 0, 64 byte hmac direction 0, 64 byte cipher direction 1, 64 byte hmac direction 1 (=> 2048 bit)
 | 
			
		||||
        key = readkey(args.tls_auth)
 | 
			
		||||
        if key == None: return critical('cannot read tls auth file')
 | 
			
		||||
        index_start = HMAC_CLIENT_KEY_START * 2
 | 
			
		||||
        index_end = (HMAC_CLIENT_KEY_START + digest_size) * 2
 | 
			
		||||
        key = key[index_start:index_end]
 | 
			
		||||
        index_start = MAX_CIPHER_KEY_LENGTH + MAX_HMAC_KEY_LENGTH + MAX_CIPHER_KEY_LENGTH
 | 
			
		||||
        index_end = index_start + MAX_HMAC_KEY_LENGTH
 | 
			
		||||
        client_key = key[index_start:index_end]
 | 
			
		||||
        index_start = MAX_CIPHER_KEY_LENGTH
 | 
			
		||||
        index_end = index_start + MAX_HMAC_KEY_LENGTH
 | 
			
		||||
        server_key = key[index_start:index_end]
 | 
			
		||||
        if args.tls_auth_inverse: client_key, server_key = server_key, client_key
 | 
			
		||||
        # reduce key size to required size
 | 
			
		||||
        client_key = client_key[:digest_size]
 | 
			
		||||
        server_key = server_key[:digest_size]
 | 
			
		||||
 | 
			
		||||
    if key: key = binascii.unhexlify(key)
 | 
			
		||||
 | 
			
		||||
    return checkserver(args.host, args.port, args.tcp, args.timeout, key, digest)
 | 
			
		||||
    return check(args.host, args.port, args.tcp, args.timeout, digest, digest_size, client_key, server_key, args.retrycount, not args.no_validation)
 | 
			
		||||
 | 
			
		||||
if __name__ == '__main__':
 | 
			
		||||
    code = main()
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue