#!/usr/libexec/platform-python

from ctypes import CDLL
from json import dump, load
from mmap import ACCESS_READ, mmap
from os import getpid, listdir, path, sysconf, sysconf_names, times
from re import search
from signal import SIGHUP, SIGINT, SIGQUIT, SIGTERM, signal
from sre_constants import error as invalid_re
from sys import argv, exit, stderr, stdout
from time import monotonic, process_time, sleep


def format_time(t):
    """
    Render a duration given in seconds as a short human-readable string.

    Durations below one minute keep one decimal place ('12.3s'). Longer
    durations are truncated to whole seconds and split into minutes, hours
    and days, e.g. '2min 5s', '1h 0min 7s' or '3d 4h 5min 6s'.
    """
    whole_seconds = int(t)

    if whole_seconds < 60:
        # Sub-minute values are the only ones that show fractions.
        return '{}s'.format(round(t, 1))

    minutes, seconds = divmod(whole_seconds, 60)
    if whole_seconds < 3600:
        return '{}min {}s'.format(minutes, seconds)

    hours, minutes = divmod(minutes, 60)
    if whole_seconds < 86400:
        return '{}h {}min {}s'.format(hours, minutes, seconds)

    days, hours = divmod(hours, 24)
    return '{}d {}h {}min {}s'.format(days, hours, minutes, seconds)


def valid_re(reg_exp):
    """
    Check that reg_exp is accepted by the regular expression engine.

    Returns None when the pattern is usable. For an invalid pattern an
    error is printed to stderr and the process exits with status 1.
    """
    try:
        # Searching an empty string is enough to force pattern compilation.
        search(reg_exp, '')
    except invalid_re:
        message = 'Invalid config: invalid regexp: {}'.format(reg_exp)
        errprint(message)
        exit(1)
    else:
        return None


def errprint(*text):
    """
    Print the given values to stderr (space-separated) and flush at once.
    """
    destination = {'file': stderr, 'flush': True}
    print(*text, **destination)


def string_to_float_convert_test(string):
    """
    Try to interpret a string value as a float.

    Returns the parsed float, or None when the text is not a valid float.
    """
    try:
        parsed = float(string)
    except ValueError:
        parsed = None
    return parsed


def mlockall():
    """
    Lock process memory.

    Calls libc mlockall() with the MCL_FUTURE flag. On failure the errno
    recorded by ctypes is reported on stderr and the process exits with
    status 1. On success a confirmation is printed when debug_self is set.
    """
    # Local imports: the file's import header only brings in CDLL.
    from ctypes import get_errno
    from os import strerror

    # NOTE(review): 2 is the MCL_FUTURE value on common Linux
    # architectures; confirm for other targets.
    MCL_FUTURE = 2

    libc = CDLL(None, use_errno=True)

    result = libc.mlockall(MCL_FUTURE)

    if result != 0:
        # The return value only signals failure; the actual error code is
        # the errno captured by ctypes because of use_errno=True.
        err = get_errno()
        errprint('ERROR: cannot lock process memory: [Errno {}] {}'.format(
            err, strerror(err)))
        errprint('Exit.')
        exit(1)

    if debug_self:
        print('process memory locked with MCL_FUTURE')


def signal_handler(signum, frame):
    """
    Shut down cleanly when a handled signal arrives.

    Closes every mapping held in lock_dict, stores the time elapsed since
    var_dict['lock_t0'] together with the snapshot history d in the dump
    file, prints CPU statistics and exits.
    """
    print('Got signal {}'.format(signum))
    print('Unlocking files and dumping snapshots')
    unlock_files(lock_dict)

    # 't' is the time already spent in the current poll interval; it is
    # read back at startup to shorten the first sleep.
    elapsed = monotonic() - var_dict['lock_t0']
    dump_d['t'] = round(elapsed, 1)
    dump_d['d'] = d
    jdump(dump_path, dump_d)

    if debug_self:
        mm_debug()
        rss_mib = round(get_self_rss() / MIB, 1)
        print('self rss: {}M'.format(rss_mib))

    cpu()

    print('Exit.')
    exit()


def get_pid_list():
    """
    Return the PIDs (as strings) listed in /proc, excluding this process.

    Only /proc entries whose first character is a decimal digit are kept.
    The own PID is filtered out rather than removed afterwards, so a /proc
    listing that does not contain it no longer raises ValueError.
    """
    return [
        entry for entry in listdir('/proc')
        if entry[0].isdecimal() and entry != self_pid
    ]


def get_uptime():
    """
    Return the first field of /proc/uptime as a float.
    """
    with open('/proc/uptime', 'rb', buffering=0) as uptime_file:
        raw = uptime_file.read()
    first_field = raw.decode().split(' ')[0]
    return float(first_field)


def get_uniq_id(pid):
    """
    Build an identifier for the process from /proc/<pid>/stat.

    Returns a (starttime, uniq_id) tuple, where uniq_id is the starttime
    token joined with the stat text preceding the last ')'. Returns None
    when the process is skipped: its stat line ends with ' 0 0 0' (treated
    as a kernel thread), it is not critical while lock_only_critical is
    set, or it disappeared while being read.
    """
    try:
        with open('/proc/' + pid + '/stat', 'rb', buffering=0) as stat_file:
            stat_text = stat_file.read().decode('utf-8', 'ignore')

        # Split at the last ')' so that a ')' inside the process name
        # cannot shift the fields that follow it.
        head, _, tail = stat_text.rpartition(')')

        if tail.endswith(' 0 0 0\n'):
            return None  # skip kthreads

        if lock_only_critical:
            name = head.partition(' (')[2]

            critical = name in name_set
            if critical and debug_map:
                print('found process with critical name:', pid, name)

            if not critical and check_cgroup:
                cgroup2 = pid_to_cgroup_v2(pid)
                for cgroup_re in cgroup_set:
                    if search(cgroup_re, cgroup2) is None:
                        continue
                    critical = True
                    if debug_map:
                        print('found process from critical cgroup:',
                              pid, name, cgroup2)
                    break

            if not critical:
                return None

        starttime = tail.split(' ')[20]
        return starttime, starttime + '+' + head
    except (FileNotFoundError, ProcessLookupError):
        return None


def get_current_set():
    """
    Return the set of existing file paths currently mapped by processes.

    Every PID from get_pid_list() that get_uniq_id() accepts is scanned via
    /proc/<pid>/maps. Results are cached in uniq_map_d per process
    identifier once the process has lived for at least min_save_lifetime
    seconds, and cache entries of processes that are gone are dropped.

    A process that vanishes while its maps file is opened or read
    (FileNotFoundError or ProcessLookupError) is reported on stderr and
    skipped. A PermissionError terminates the program with status 1.
    """
    m0 = monotonic()
    p0 = process_time()

    if debug_map:
        print('Looking for mapped files...')

    new_map_set = set()
    uniq_set = set()

    pid_list = get_pid_list()
    uptime = get_uptime()

    for pid in pid_list:

        s = get_uniq_id(pid)
        if s is None:
            continue

        starttime, uniq_id = s
        uniq_set.add(uniq_id)

        # Known process: reuse the cached mapping list.
        if uniq_id in uniq_map_d:
            new_map_set.update(uniq_map_d[uniq_id])
            continue

        if debug_map:
            print('finding mapped files for new process:', uniq_id)

        try:
            with open('/proc/' + pid + '/maps', 'rb', buffering=0) as f:
                lines_list = f.read().decode('utf-8', 'ignore').split('\n')
        except PermissionError as e:
            errprint(e)
            exit(1)
        except (FileNotFoundError, ProcessLookupError) as e:
            # The process exited between listing /proc and reading maps.
            errprint(e)
            continue

        # Keep what follows the first '/' of each line (the pathname
        # without its leading slash); lines without '/' yield ''.
        uniq_map_set = {line.partition('/')[2] for line in lines_list}

        new_map_set.update(uniq_map_set)

        # Only cache processes that have lived long enough.
        lifetime = uptime - float(starttime) / SC_CLK_TCK
        if lifetime >= min_save_lifetime:
            uniq_map_d[uniq_id] = uniq_map_set

    # Forget cached processes that were not seen during this scan.
    for uniq_id in set(uniq_map_d) - uniq_set:
        del uniq_map_d[uniq_id]

    new_map_set.discard('')

    current_map_set = set()
    for pathname in {'/' + i for i in new_map_set}:
        if path.exists(pathname):
            current_map_set.add(pathname)
        elif debug_map:
            print('skip:', pathname)

    if debug_map:
        for pathname in sorted(current_map_set):
            print('mapped:', pathname)

    m1 = monotonic()
    p1 = process_time()

    if debug_map:
        print('Found {} mapped files in {}s (process time: {}s)'.format(len(
            current_map_set), round(m1 - m0, 3), round(p1 - p0, 3)))

    return current_map_set


def get_sorted(rp_set):
    """
    Return (path, size) pairs for lockable paths, smallest file first.

    A path is kept when it matches lock_path_regex and its size is greater
    than zero. Paths that no longer exist are dropped.
    """
    sizes = {}
    for candidate in rp_set:
        candidate = str(candidate)

        if search(lock_path_regex, candidate) is None:
            if debug_map:
                print("skip (doesn't match $LOCK_PATH_REGEX): " + candidate)
            continue

        try:
            size = path.getsize(candidate)
        except FileNotFoundError as e:
            if debug_map:
                print(e)
            continue

        if size > 0:
            sizes[candidate] = size

    return sorted(sizes.items(), key=lambda item: item[1])


def get_sorted_locked():
    """
    Return (path, size) pairs of the locked files, largest first.

    The size is the second element of each lock_dict value.
    """
    pairs = [(rp, entry[1]) for rp, entry in lock_dict.items()]
    pairs.sort(key=lambda pair: pair[1], reverse=True)
    return pairs


def _mmap_file(f_realpath, size):
    """
    Map f_realpath read-only and register the mapping in lock_dict.

    A warning is printed when the mapped length differs from size.
    """
    with open(f_realpath, 'rb') as f:
        mm = mmap(f.fileno(), 0, access=ACCESS_READ)

    mm_len = len(mm)

    if mm_len != size:
        print('W: mm_len != size:', f_realpath)

    lock_dict[f_realpath] = (mm, mm_len)


def lock_files(rp_set):
    """
    Map the files of rp_set into memory, smallest first, within the limits.

    Files larger than max_file_size are skipped. When the remaining room
    below max_total_size is too small for a file, the largest mapping that
    was held before this call may be closed to make room, provided the new
    file is smaller than it. Once those candidates are used up, the file
    is handled like any other that does not fit.

    Candidates are sorted by size, so outside debug mode the loop stops at
    the first file that cannot be placed; in debug mode it continues in
    order to report every skipped file.
    """
    if debug_lock:
        print('locking new files...')

    def report_no_room(total_size, size, f_realpath):
        print('skip (total_size ({}M) + size ({}M) > max_total_size) '
              '{}'.format(round(total_size / MIB, 1),
                          round(size / MIB, 1), f_realpath))

    sorted_list = get_sorted(rp_set)

    # Eviction candidates: mappings held before this call, largest first.
    sorted_locked = get_sorted_locked()
    len_sorted_locked = len(sorted_locked)
    last_index = 0

    for f_realpath, size in sorted_list:

        # skip large files
        if size > max_file_size:
            if debug_lock:
                print(
                    'skip (file size {}M > $MAX_FILE_SIZE_MIB) {}'.format(
                        round(size / MIB, 1), f_realpath))
                continue
            else:
                break

        if f_realpath in lock_dict:
            continue

        total_size = get_total_size()
        avail = max_total_size - total_size
        evicted = False

        if avail < size:
            can_evict = False

            # Guard against running past the candidate list: the original
            # code indexed sorted_locked[last_index] unconditionally.
            if last_index < len_sorted_locked:
                old_f, old_s = sorted_locked[last_index]
                # NOTE(review): the size is subtracted before comparing
                # with size again, which is stricter than "fits after
                # eviction"; kept as in the original, confirm the intent.
                can_evict = size < old_s and avail + old_s - size >= size

            if not can_evict:
                if debug_lock:
                    report_no_room(total_size, size, f_realpath)
                    continue
                else:
                    break

            lock_dict[old_f][0].close()
            del lock_dict[old_f]
            last_index += 1
            evicted = True

        try:
            # A replacement is always announced; plain locking only in
            # debug mode (same output as before).
            if evicted or debug_lock:
                print('locking ({}M) {}'.format(
                    round(size / MIB, 1), f_realpath))

            _mmap_file(f_realpath, size)

        except ValueError as e:
            # mmap() rejects a file that became empty after it was sized.
            errprint(e)
            continue
        except OSError as e:
            errprint(e)
            break

    if debug_1:
        mm_debug()


def unlock_files(elements):
    """
    Close the mappings of the given paths and drop them from lock_dict.

    elements may be lock_dict itself, so entries are removed only after
    the iteration over elements has finished. Paths without an entry in
    lock_dict are reported in debug mode and otherwise ignored.
    """
    released = set()

    for f_realpath in elements:
        try:
            lock_dict[f_realpath][0].close()
        except KeyError:
            if debug_lock:
                print('key error:', f_realpath)
            continue

        released.add(f_realpath)

        if debug_lock:
            print('unlocked:', f_realpath)

    for released_path in released:
        try:
            lock_dict.pop(released_path)
        except KeyError:
            # Report the key that actually failed (the original printed
            # the stale loop variable of the first loop).
            print('key error:', released_path)


def get_self_rss():
    """
    Return the second field of /proc/self/statm multiplied by SC_PAGESIZE.

    Callers print the result as this process's RSS in MiB.
    """
    with open('/proc/self/statm') as statm:
        fields = statm.readline().split(' ')
    resident_pages = int(fields[1])
    return resident_pages * SC_PAGESIZE


def get_total_size():
    """
    Return the sum of the sizes recorded in lock_dict.

    The size is the second element of each lock_dict value.
    """
    return sum(entry[1] for entry in lock_dict.values())


def mm_debug():
    """
    Print the total locked size in MiB and the number of locked files.
    """
    locked_mib = round(get_total_size() / MIB, 1)
    file_count = len(lock_dict)
    print('currently locked {}M, {} files'.format(locked_mib, file_count))


def string_to_int_convert_test(string):
    """
    Try to interpret a string value as an integer.

    Returns the parsed int, or None when the text is not a valid integer.
    """
    try:
        parsed = int(string)
    except ValueError:
        parsed = None
    return parsed


def get_final_set():
    """
    d, lock_path_set -> final_set

    A path from d is selected when, for at least one (min_entry,
    from_latest) rule in lock_path_set, its last from_latest snapshot
    marks contain at least min_entry YES values.
    """
    final_set = set()

    for rp in d:
        satisfied = any(
            d[rp][-from_latest:].count(YES) >= min_entry
            for min_entry, from_latest in lock_path_set
        )
        if satisfied:
            final_set.add(rp)

    return final_set


def rotate_snapshots():
    """
    current_set, minus, max_store_num, d
    -> d

    Appends a YES mark to the history of every path in current_set and a
    NO mark to every path in minus, keeping at most max_store_num marks.
    A path from minus whose history no longer contains YES is removed.
    """
    def extended(rp, mark):
        # Copy of the stored history with the new mark appended and the
        # oldest mark dropped when the limit is exceeded.
        history = d[rp][-max_store_num:]
        history.append(mark)
        if len(history) > max_store_num:
            history.pop(0)
        return history

    for rp in current_set:
        if rp not in d:
            d[rp] = [YES]
            continue
        d[rp] = extended(rp, YES)

    for rp in minus:
        history = extended(rp, NO)
        if YES not in history:
            del d[rp]
            continue
        d[rp] = history


def jdump(pathname, data):
    """
    Write data to pathname as JSON with sorted keys, one item per line.

    When the file cannot be created or written because of missing
    permissions or a missing directory, the error is printed to stderr and
    the process exits with status 1.
    """
    try:
        with open(pathname, 'w') as out:
            dump(data, out, indent=0, sort_keys=True)
    except (FileNotFoundError, PermissionError) as err:
        errprint(err)
        exit(1)


def jload(pathname):
    """
    Read the file at pathname and return its parsed JSON content.
    """
    with open(pathname) as src:
        parsed = load(src)
    return parsed


def pid_to_cgroup_v2(pid):
    """
    Return the cgroup path from line cgroup_v2_index of /proc/<pid>/cgroup.

    The first three characters of that line and its final character are
    stripped. An empty string is returned when the file does not exist or
    has no line at that index.
    """
    try:
        with open('/proc/' + pid + '/cgroup') as cgroup_file:
            lines = cgroup_file.readlines()
    except FileNotFoundError:
        return ''

    for index, line in enumerate(lines):
        if index == cgroup_v2_index:
            return line[3:-1]
    return ''


def get_cgroup2_index():
    """
    Find cgroup-line position in /proc/[pid]/cgroup file.

    Returns the index of the last line of /proc/self/cgroup that starts
    with '0::', or None when there is no such line.
    """
    with open('/proc/self/cgroup') as cgroup_file:
        matches = [
            index for index, line in enumerate(cgroup_file)
            if line.startswith('0::')
        ]

    return matches[-1] if matches else None


def cpu():
    """
    Print the time since start_time and the CPU time used by the process.

    The average load is the user plus system CPU time as a percentage of
    the elapsed time.
    """
    elapsed = monotonic() - start_time
    process_times = times()
    user_time = process_times[0]
    system_time = process_times[1]
    cpu_total = user_time + system_time
    average = cpu_total / elapsed * 100

    report = 'Uptime {}, CPU time {}s (user {}s, sys {}s), avg {}%'
    print(report.format(
        format_time(elapsed),
        round(cpu_total, 2),
        user_time,
        system_time,
        round(average, 2),
    ))


def valid_v(x):
    """
    Convert one $VERBOSITY flag character to a boolean.

    '1' gives True and '0' gives False. Any other value prints an error to
    stderr and exits with status 1.
    """
    if x == '1':
        return True
    if x == '0':
        return False
    errprint('Invalid $VERBOSITY value')
    exit(1)


###############################################################################


# Reference point for the "Uptime" figure printed by cpu().
start_time = monotonic()

# Bytes per MiB; used for size limits and for human-readable output.
MIB = 1024**2

# Own PID as a string, in the same form as the entries of /proc.
self_pid = str(getpid())

# Divisor applied to the starttime field of /proc/<pid>/stat before it is
# compared with the uptime in get_current_set().
SC_CLK_TCK = sysconf(sysconf_names['SC_CLK_TCK'])

# Cache: process identifier -> set of mapped paths (without leading '/').
uniq_map_d = dict()

# Minimum process lifetime (compared against seconds of uptime) before its
# mapping list is cached in uniq_map_d.
min_save_lifetime = 300

# Snapshot history is persisted here on exit and restored at startup.
dump_path = '/var/lib/prelockd/dump.json'

# Line index of the '0::' entry in /proc/self/cgroup, or None if absent.
cgroup_v2_index = get_cgroup2_index()

# Command line handling. Two forms are accepted:
#   -r           remove the dump file and exit
#   -c CONFIG    run with the given config file
a = argv[1:]
la = len(a)
if la == 0:
    errprint('invalid input: missing CLI options')
    exit(1)
elif la == 1:
    if a[0] == '-r':
        from os import remove
        try:
            remove(dump_path)
            print('OK')
            exit()
        except (FileNotFoundError, PermissionError) as e:
            errprint(e)
            exit(1)
    else:
        errprint('invalid input')
        exit(1)
elif la == 2:
    if a[0] == '-c':
        # Path of the config file; opened by the config parser below.
        config = a[1]
    else:
        errprint('invalid input')
        exit(1)
else:
    errprint('invalid input: too many options')
    exit(1)


# Read the value of the first /proc/meminfo line: the text after ':' with
# its last four characters (unit suffix and newline) removed. The value is
# later multiplied by 1024 to obtain bytes.
with open('/proc/meminfo') as f:
    mem_list = f.readlines()
    mem_total = int(mem_list[0].split(':')[1][:-4])

# Field names of all meminfo lines.
# NOTE(review): not referenced anywhere else in the visible part of the
# file; possibly unused.
mem_list_names = []

for s in mem_list:
    mem_list_names.append(s.split(':')[0])

# Multiplier applied to the /proc/self/statm field in get_self_rss().
SC_PAGESIZE = sysconf(sysconf_names['SC_PAGESIZE'])

# Raw '$KEY' -> value strings collected from the config file.
config_dict = dict()

# (min_entry, from_latest) rules collected from @LOCK_PATH lines.
lock_path_set = set()

# Largest from_latest over all rules: number of snapshot marks kept per path.
max_store_num = 0

# Critical process names and cgroup regexes from the '@CRITICAL_*' lines.
name_set = set()
cgroup_set = set()

# Parse the config file line by line.
#   '$KEY = value'                          -> config_dict (duplicates fatal)
#   '@LOCK_PATH MIN_ENTRY=N FROM_LATEST=M'  -> lock_path_set, max_store_num
#   '@CRITICAL_NAME_LIST a, b, c'           -> name_set
#   '@CRITICAL_CGROUP2_REGEX <regex>'       -> cgroup_set
# Every invalid @LOCK_PATH line terminates the program with status 1.
try:

    with open(config) as f:

        for line in f:

            if line[0] == '$' and '=' in line:
                key, _, value = line.partition('=')
                key = key.rstrip()
                value = value.strip()
                if key in config_dict:
                    errprint('config key {} duplication'.format(key))
                    exit(1)
                config_dict[key] = value

            if line[0] == '@':

                if line.startswith('@LOCK_PATH ') and '=' in line:
                    a_list = line.partition('@LOCK_PATH ')[2].split()
                    lal = len(a_list)
                    if lal != 2:
                        errprint('invalid conf')
                        exit(1)

                    a_dict = dict()

                    for pair in a_list:
                        key, _, value = pair.partition('=')
                        a_dict[key] = value

                    # Both options are mandatory; report a missing one as
                    # a config error instead of failing with KeyError.
                    if ('MIN_ENTRY' not in a_dict
                            or 'FROM_LATEST' not in a_dict):
                        errprint('Invalid config: @LOCK_PATH requires '
                                 'MIN_ENTRY and FROM_LATEST')
                        exit(1)

                    min_entry = string_to_int_convert_test(a_dict['MIN_ENTRY'])
                    if min_entry is None:
                        errprint('Invalid config: invalid MIN_ENTRY: not int')
                        # Stop here: comparing None below would raise.
                        exit(1)

                    from_latest = string_to_int_convert_test(
                        a_dict['FROM_LATEST'])
                    if from_latest is None:
                        errprint(
                            'Invalid config: invalid FROM_LATEST: not int')
                        exit(1)

                    if min_entry > from_latest:
                        errprint('invalid conf')
                        exit(1)

                    if min_entry < 1 or from_latest < 1:
                        errprint('invalid conf')
                        exit(1)

                    # The history must be long enough for the widest rule.
                    if from_latest > max_store_num:
                        max_store_num = from_latest

                    lock_path_set.add((min_entry, from_latest))

                if line.startswith('@CRITICAL_NAME_LIST '):
                    a_list = line.partition(
                        '@CRITICAL_NAME_LIST ')[2].split(',')

                    for name in a_list:
                        name_set.add(name.strip(' \n'))

                if line.startswith('@CRITICAL_CGROUP2_REGEX '):
                    cgroup_re = line.partition(
                        '@CRITICAL_CGROUP2_REGEX ')[2].strip(' \n')
                    # valid_re() exits on an invalid pattern, so reaching
                    # the add means the pattern was accepted.
                    if valid_re(cgroup_re) is None:
                        cgroup_set.add(cgroup_re)

except (PermissionError, UnicodeDecodeError, IsADirectoryError,
        IndexError, FileNotFoundError) as e:
    errprint('Invalid config: {}. Exit.'.format(e))
    exit(1)


# Drop the empty name that an empty list or a stray comma would produce.
name_set.discard('')


# cgroup matching is only attempted when at least one regex is configured.
check_cgroup = bool(len(cgroup_set))


# $VERBOSITY (mandatory): exactly four characters, each '0' or '1',
# mapping in order to debug_1, debug_self, debug_lock and debug_map.
if '$VERBOSITY' in config_dict:
    verbosity = config_dict['$VERBOSITY']

    if len(verbosity) != 4:
        errprint('invalid $VERBOSITY value')
        exit(1)

    debug_1 = valid_v(verbosity[0])
    debug_self = valid_v(verbosity[1])
    debug_lock = valid_v(verbosity[2])
    debug_map = valid_v(verbosity[3])

else:
    errprint('missing $VERBOSITY key')
    exit(1)


# $LOCK_ONLY_CRITICAL (mandatory): the literal text 'True' or 'False'.
if '$LOCK_ONLY_CRITICAL' in config_dict:
    lock_only_critical = config_dict['$LOCK_ONLY_CRITICAL']
    if lock_only_critical == 'True':
        lock_only_critical = True
    elif lock_only_critical == 'False':
        lock_only_critical = False
    else:
        errprint('invalid $LOCK_ONLY_CRITICAL value')
        exit(1)
else:
    errprint('missing $LOCK_ONLY_CRITICAL key')
    exit(1)


# $MAX_FILE_SIZE_MIB (mandatory): per-file size limit; stored in bytes.
if '$MAX_FILE_SIZE_MIB' in config_dict:
    string = config_dict['$MAX_FILE_SIZE_MIB']
    max_file_size_mib = string_to_float_convert_test(string)
    if max_file_size_mib is None:
        errprint('invalid $MAX_FILE_SIZE_MIB value')
        exit(1)
    max_file_size = int(max_file_size_mib * MIB)
else:
    errprint('missing $MAX_FILE_SIZE_MIB key')
    exit(1)


# $LOCK_PATH_REGEX (mandatory): only matching paths are locked.
# valid_re() exits the program if the pattern does not compile.
if '$LOCK_PATH_REGEX' in config_dict:
    lock_path_regex = config_dict['$LOCK_PATH_REGEX']
    valid_re(lock_path_regex)
else:
    errprint('missing $LOCK_PATH_REGEX key')
    exit(1)


# $MAX_TOTAL_SIZE_PERCENT (mandatory): total limit as a percentage of
# mem_total; converted to bytes.
if '$MAX_TOTAL_SIZE_PERCENT' in config_dict:
    string = config_dict['$MAX_TOTAL_SIZE_PERCENT']
    max_total_size_percent = string_to_float_convert_test(string)
    if max_total_size_percent is None:
        errprint('invalid $MAX_TOTAL_SIZE_PERCENT value')
        exit(1)
    max_total_size_pc = int(
        mem_total * max_total_size_percent / 100) * 1024  # bytes
else:
    errprint('missing $MAX_TOTAL_SIZE_PERCENT key')
    exit(1)


# $MAX_TOTAL_SIZE_MIB (mandatory): total limit in MiB. Despite its name,
# max_total_size_mib holds BYTES after the conversion below.
if '$MAX_TOTAL_SIZE_MIB' in config_dict:
    string = config_dict['$MAX_TOTAL_SIZE_MIB']
    max_total_size_mib = string_to_float_convert_test(string)
    if max_total_size_mib is None:
        errprint('invalid $MAX_TOTAL_SIZE_MIB value')
        exit(1)
    max_total_size_mib = int(max_total_size_mib * MIB)  # bytes
else:
    errprint('missing $MAX_TOTAL_SIZE_MIB key')
    exit(1)


# The effective total limit is the smaller of the two limits (in bytes).
if max_total_size_mib <= max_total_size_pc:
    max_total_size = max_total_size_mib
else:
    max_total_size = max_total_size_pc


# $POLL_INTERVAL_SEC (mandatory): passed to sleep() between scans.
# NOTE(review): the value is not range-checked here.
if '$POLL_INTERVAL_SEC' in config_dict:
    string = config_dict['$POLL_INTERVAL_SEC']
    interval = string_to_float_convert_test(string)
    if interval is None:
        errprint('invalid $POLL_INTERVAL_SEC value')
        exit(1)
else:
    errprint('missing $POLL_INTERVAL_SEC key')
    exit(1)


# Report the config file by absolute path.
config = path.abspath(config)

print('Starting prelockd with config {}'.format(config))

# With debug_self set, print the effective configuration. Size limits are
# stored in bytes and converted back to MiB for display; the
# '$MAX_TOTAL_SIZE_PERCENT' line shows the limit derived from the
# percentage, not the percentage itself.
# NOTE(review): '$LOCK_PATH_REGEX' is printed twice (first and seventh
# line of this block).
if debug_self:
    print('$LOCK_PATH_REGEX:            ', lock_path_regex)
    print('$MAX_FILE_SIZE_MIB:          ', max_file_size / MIB)
    print('$MAX_TOTAL_SIZE_MIB:         ', round(max_total_size_mib / MIB, 1))
    print('$MAX_TOTAL_SIZE_PERCENT:     ', round(
        max_total_size_pc / MIB, 1), '(MiB)')
    print('max_total_size:              ', round(
        max_total_size / MIB, 1), '(MiB)')
    print('$VERBOSITY:                  ', verbosity)
    print('$LOCK_PATH_REGEX:            ', lock_path_regex)
    print('@LOCK_PATH                   ', lock_path_set)
    print('$LOCK_ONLY_CRITICAL          ', lock_only_critical)
    print('@CRITICAL_CGROUP2_REGEX set: ', cgroup_set)
    print('@CRITICAL_NAME_LIST          ', list(name_set))


# max_store_num stays 0 when no valid @LOCK_PATH rule was read.
if max_store_num == 0:
    print('WARNING: lock rules are empty!')

# Request memory locking (MCL_FUTURE) before any file is mapped.
mlockall()

# dump_d:    content written to / read from dump_path ('t' and 'd' keys)
# var_dict:  mutable runtime values shared with signal_handler ('lock_t0')
# lock_dict: path -> (mmap object, mapped length) for every locked file
dump_d = dict()
var_dict = dict()
lock_dict = dict()

# All of these signals trigger the same clean shutdown.
sig_list = [SIGTERM, SIGINT, SIGQUIT, SIGHUP]

for i in sig_list:
    signal(i, signal_handler)

if debug_self:
    mm_debug()

# Marks stored in the per-path snapshot history.
YES = 1
NO = 0


if debug_self:
    self_rss = get_self_rss()
    print('self rss: {}M'.format(round(self_rss / MIB, 1)))

# Start of the current poll interval; signal_handler reports the time
# elapsed since this moment.
var_dict['lock_t0'] = monotonic()

# Restore state from the previous run, if a dump exists: reload the
# snapshot history, lock the files it selects right away, then sleep only
# for the part of the poll interval that had not elapsed when the previous
# run exited.
try:
    dump_d = jload(dump_path)
    d = dump_d['d']
    lock_t = dump_d['t']
    final_set = get_final_set()
    lock_files(final_set)
    # Back-date the interval start by the time already spent last run.
    lock_t0 = monotonic() - lock_t
    var_dict['lock_t0'] = lock_t0
    extra_t = interval - lock_t
    if extra_t > 0:
        stdout.flush()
        sleep(extra_t)
except FileNotFoundError as e:
    # No dump yet: start with an empty history.
    if debug_self:
        print(e)
    d = dict()
    lock_t = 0
except Exception as e:
    # Any other problem with the dump (e.g. malformed content) is printed
    # and the history starts empty.
    print(e)
    d = dict()
    lock_t = 0

# Main loop: scan mapped files, update the snapshot history, then bring
# the set of locked files in line with the rules. Runs until a handled
# signal calls signal_handler(), which exits the process.
while True:
    current_set = get_current_set()
    # NOTE(review): len_cur is not used in the visible part of the file.
    len_cur = len(current_set)
    d_set = set(d)
    # Paths with history that are not mapped by any process right now.
    minus = d_set - current_set
    rotate_snapshots()
    final_set = get_final_set()
    old_final_set = set(lock_dict)
    # Unlock what no longer qualifies before locking new files.
    unlock_it = old_final_set - final_set
    unlock_files(unlock_it)
    lock_it = final_set - old_final_set
    lock_files(lock_it)
    # Start of the new poll interval (read by signal_handler).
    var_dict['lock_t0'] = monotonic()
    if debug_self:
        self_rss = get_self_rss()
        print('self rss: {}M'.format(round(self_rss / MIB, 1)))
        cpu()
    stdout.flush()
    sleep(interval)
