#!/usr/bin/env python
# Hacked together by Ali Hassani (@alihassanijr)
# Just a better way to watch SMI

# Requires nvidia-smi
# Requires `rich` python package

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from time import sleep
from datetime import datetime
import csv
from subprocess import run, PIPE
from rich.console import Console
from rich.table import Table
from rich.align import Align
from rich.live import Live
import argparse
import os
import pwd


# Version string of this tool; embedded in the title line built by get_info().
VER = '0.1'
# Date of the last update to this script (not referenced by the visible code).
UPDATED = '03-08-2022'
# Temperature bounds used by GPU.temp_frac to map a reading onto a fraction
# for colouring: MIN_TEMP maps to 0.0 and MAX_TEMP to 1.0.
MIN_TEMP = 50
MAX_TEMP = 100

# NOTE(review): not referenced anywhere in the visible source -- presumably
# meant to hide processes owned by these users; confirm before relying on it.
EXCLUDED_USERS = ['root']


def n_gpus():
    """Return the number of entries under the NVIDIA driver's GPU directory.

    Returns 0 when ``/proc/driver/nvidia/gpus`` is not a directory.
    """
    gpu_dir = '/proc/driver/nvidia/gpus'
    return len(os.listdir(gpu_dir)) if os.path.isdir(gpu_dir) else 0


def safe_cast(var, t, default='None'):
    """Convert ``var`` with the callable ``t``, falling back to ``default``.

    Only ``ValueError`` and ``TypeError`` raised by the conversion are
    swallowed; any other exception propagates to the caller.
    """
    try:
        converted = t(var)
    except (ValueError, TypeError):
        converted = default
    return converted


def frac_to_color(frac):
    """Map a fraction onto a rich colour name, from green (low) to red (high).

    Each bucket is inclusive of its upper bound. Anything above 0.9, or any
    value that fails every comparison, falls through to ``'bright_red'``.
    """
    buckets = (
        (0.2, 'green1'),
        (0.3, 'green3'),
        (0.4, 'gold1'),
        (0.5, 'orange1'),
        (0.6, 'dark_orange'),
        (0.7, 'orange_red1'),
        (0.8, 'deep_pink2'),
        (0.9, 'red1'),
    )
    for upper_bound, colour in buckets:
        if frac <= upper_bound:
            return colour
    return 'bright_red'


class GPU:
    """One responsive GPU, built from a row of ``nvidia-smi`` query output.

    Numeric fields that cannot be parsed are stored as ``-1``. The ``*_str``
    properties render such fields as an error marker, and the fraction
    properties return ``10`` for them so that ``frac_to_color`` picks its
    highest (red) bucket.
    """

    def __init__(self, index, gpu_dict):
        """Populate the GPU from ``gpu_dict``.

        ``gpu_dict`` must provide the keys ``model``, ``bus_id``, ``util``,
        ``memory``, ``used_memory``, ``fan``, ``temp``, ``power``,
        ``power_usage`` and ``free``.
        """
        self.index = safe_cast(index, int, -1)
        self.defective = False
        self.model = gpu_dict['model'].replace('NVIDIA', '').strip()
        self.bus_id = gpu_dict['bus_id']
        self.util = safe_cast(gpu_dict['util'], int, -1)
        self.memory = safe_cast(gpu_dict['memory'], int, -1)
        self.memory_used = safe_cast(gpu_dict['used_memory'], int, -1)
        self.fan = safe_cast(gpu_dict['fan'], int, -1)
        # A fan field containing 'N/A' means "no fan reading", not an error.
        self.has_fan = 'N/A' not in gpu_dict['fan']
        self.temp = safe_cast(gpu_dict['temp'], int, -1)
        # Power values are parsed as floats and truncated to whole numbers.
        self.power = int(safe_cast(gpu_dict['power'], float, -1))
        self.power_usage = int(safe_cast(gpu_dict['power_usage'], float, -1))
        self.free = gpu_dict['free']
        self.owned = False

    @property
    def index_str(self):
        """Index as text, or ``'E'`` when the index could not be parsed."""
        if self.index < 0:
            return 'E'
        return str(self.index)

    @property
    def memory_str(self):
        """Total memory as ``'<n> MB'``, or an error marker."""
        if self.memory < 0:
            return ' ERR '
        return f'{self.memory} MB'

    @property
    def memory_used_str(self):
        """Used memory right-aligned as ``'<n> MB'``, or an error marker."""
        if self.memory_used < 0:
            return ' ERR '
        return f'{str(self.memory_used).rjust(5)} MB'

    @property
    def fan_str(self):
        """Fan speed as a percentage, ``' N/A '`` without a fan, else error."""
        if self.fan < 0:
            if not self.has_fan:
                return ' N/A '
            return ' ERR '
        return f'{str(self.fan).rjust(3)} %'

    @property
    def temp_str(self):
        """Temperature as ``'<n> C'``, or ``'ERR'``."""
        if self.temp < 0:
            return 'ERR'
        return f'{str(self.temp).rjust(3)} C'

    @property
    def util_str(self):
        """Utilisation as ``'<n> %'``, or ``'ERR'``."""
        if self.util < 0:
            return 'ERR'
        return f'{str(self.util).rjust(3)} %'

    @property
    def power_usage_str(self):
        """Current power draw as ``'<n> W'``, or ``'ERR'``."""
        if self.power_usage < 0:
            return 'ERR'
        return f'{str(self.power_usage).rjust(3)} W'

    @property
    def power_str(self):
        """Power limit as ``'<n> W'``, or ``'ERR'``."""
        if self.power < 0:
            return 'ERR'
        return f'{str(self.power).rjust(3)} W'

    @property
    def memory_util(self):
        """Used / total memory, or ``10`` when either value is unusable."""
        # A total of 0 is treated as unusable too, to avoid dividing by zero.
        if self.memory_used < 0 or self.memory <= 0:
            return 10
        return self.memory_used / self.memory

    @property
    def gpu_util(self):
        """Utilisation as a fraction of 100, or ``10`` on error."""
        if self.util < 0:
            return 10
        return self.util / 100

    @property
    def fan_speed(self):
        """Fan speed as a fraction; ``0`` without a fan, ``10`` on error."""
        if self.fan < 0:
            if not self.has_fan:
                return 0
            return 10
        return self.fan / 100

    @property
    def temp_frac(self):
        """Temperature scaled so MIN_TEMP -> 0 and MAX_TEMP -> 1 (10 on error)."""
        if self.temp < 0:
            return 10
        return (self.temp - MIN_TEMP) / (MAX_TEMP - MIN_TEMP)

    @property
    def power_util(self):
        """Power draw / power limit, or ``10`` when either value is unusable."""
        # A limit of 0 is treated as unusable too, to avoid dividing by zero.
        if self.power_usage < 0 or self.power <= 0:
            return 10
        return self.power_usage / self.power

    def get_table_info(self):
        """Return ``(cells, meta)`` for this GPU's table row.

        ``cells`` maps column name to display text; ``meta['STYLE']`` lists
        one rich style per column, in the same order.
        """
        # Index and model share one style: default when free, dim when a
        # process of the current user runs on it, red when only others do.
        if self.free:
            row_style = 'default'
        elif self.owned:
            row_style = 'dim'
        else:
            row_style = 'red'
        cells = {
            '#': self.index_str,
            'GPU': self.model,
            'Memory': f'{self.memory_used_str} / {self.memory_str}',
            'Util': self.util_str,
            'Fan': self.fan_str,
            'Temp': self.temp_str,
            'Power': f'{self.power_usage_str} / {self.power_str}',
        }
        meta = {
            'STYLE': [
                row_style,
                row_style,
                frac_to_color(self.memory_util),
                frac_to_color(self.gpu_util),
                frac_to_color(self.fan_speed),
                frac_to_color(self.temp_frac),
                frac_to_color(self.power_util),
            ]
        }
        return cells, meta


class DefectiveGPU:
    """Placeholder row for a GPU that is listed but cannot report its stats."""

    def __init__(self, index, gpu_dict):
        """Keep only the index and model name; memory is recorded as 0."""
        self.index = safe_cast(index, int, -1)
        self.defective = True
        self.model = gpu_dict['model'].replace('NVIDIA', '').strip()
        self.memory = 0

    @property
    def index_str(self):
        """Index as text, or ``'E'`` when the index could not be parsed."""
        return 'E' if self.index < 0 else str(self.index)

    def get_table_info(self):
        """Return ``(cells, meta)`` for an all-red row marked as defective."""
        cells = {
            '#': self.index_str,
            'GPU': self.model,
            'Memory': 'DEFECTIVE',
            'Util': '-',
            'Fan': '-',
            'Temp': '-',
            'Power': '-   /   -',
        }
        # One style entry per column, all red.
        meta = {'STYLE': ['red'] * 7}
        return cells, meta


class DeadGPU(DefectiveGPU):
    """Defective GPU that returned no usable data; its model shows as DEAD."""

    def __init__(self, index):
        placeholder = {'model': 'DEAD'}
        super().__init__(index, placeholder)


class ComputeProcess:
    """One compute process reported by ``nvidia-smi`` for a given GPU."""

    def __init__(self, gpu_inst, process_dict):
        """Build the process record.

        ``gpu_inst`` supplies the GPU ``index`` and total ``memory``;
        ``process_dict`` must provide ``pid``, ``name`` and ``used_memory``.
        """
        self.gpu = gpu_inst.index
        self.pid = process_dict['pid']
        self.user, self.uid = get_user(self.pid)
        self.name = process_dict['name']
        # Non-numeric values (e.g. '[N/A]') become -1 instead of raising,
        # matching how GPU parses its own fields.
        self.memory_used = safe_cast(process_dict['used_memory'], int, -1)
        self.memory = gpu_inst.memory
        self.owned = os.getuid() == self.uid

    @property
    def memory_str(self):
        """Total memory of the hosting GPU as ``'<n> MB'``."""
        return f'{self.memory} MB'

    @property
    def memory_used_str(self):
        """Memory used by the process, or an error marker if unknown."""
        if self.memory_used < 0:
            return ' ERR '
        return f'{str(self.memory_used).rjust(5)} MB'

    @property
    def memory_util(self):
        """Fraction of the GPU's memory used by this process.

        Returns ``10`` (which ``frac_to_color`` maps to its highest bucket)
        when the usage is unknown or the GPU total is not a number of at
        least 10.
        """
        if self.memory_used < 0:
            return 10
        # The original test was `self.memory is not float`, which is always
        # true for a value, so the real fraction was never returned.
        if not isinstance(self.memory, (int, float)) or self.memory < 10:
            return 10
        return self.memory_used / self.memory

    def get_table_info(self):
        """Return ``(cells, meta)`` for this process's table row."""
        # Rows of other users' processes are dimmed.
        row_style = 'default' if self.owned else 'dim'
        cells = {
            'GPU': f'{self.gpu}',
            'PID': f'{self.pid}',
            'User': self.user,
            # Long process names are truncated to 25 characters.
            'Process': self.name[:25],
            'Memory': self.memory_used_str,
        }
        meta = {
            'STYLE': [
                row_style,
                row_style,
                row_style,
                row_style,
                frac_to_color(self.memory_util),
            ]
        }
        return cells, meta


def list2dict(obj_list):
    """Turn objects exposing ``get_table_info()`` into a table description.

    Each object's ``get_table_info()`` must return ``(cells, meta)`` where
    ``cells`` maps column name to value. The result has three keys:

    * ``cols`` -- every column name, in first-seen order;
    * ``rows`` -- one list of strings per object, with ``'-'`` for columns
      the object does not provide;
    * ``meta`` -- the ``meta`` value of each object, in order.
    """
    # Query each object exactly once so the column scan and the row build
    # see the same data, and so one-shot iterables are handled too.
    infos = [obj.get_table_info() for obj in obj_list]

    cols = []
    for cells, _ in infos:
        for key in cells:
            if key not in cols:
                cols.append(key)

    rows = [
        [str(cells[key]) if key in cells else '-' for key in cols]
        for cells, _ in infos
    ]
    meta = [row_meta for _, row_meta in infos]
    return {'cols': cols, 'rows': rows, 'meta': meta}


def run_command(command):
    """Run ``command`` through the shell and return its stdout split on newlines.

    Only stdout is captured; it is decoded as UTF-8. Output that ends with a
    newline therefore yields a trailing empty string.
    """
    completed = run([command], shell=True, stdout=PIPE)
    text = completed.stdout.decode("utf-8")
    return text.split('\n')


def get_user(pid):
    """Return ``(user_name, uid)`` for the owner of process ``pid``.

    Ownership is read from the ``/proc/<pid>`` entry. If that entry cannot
    be stat'ed (for example the process has already exited), ``('-', 0)``
    is returned. If the uid has no passwd entry, the uid itself is used as
    the user name.
    """
    try:
        # Stat directly rather than check-then-stat: the process may exit
        # between the two calls.
        uid = os.stat(f"/proc/{pid}").st_uid
    except OSError:
        # NOTE(review): uid 0 as the "unknown" marker makes such processes
        # look owned when this tool runs as root -- kept for compatibility.
        return '-', 0
    try:
        name = pwd.getpwuid(uid)[0]
    except KeyError:
        # No passwd entry for this uid; show the number instead of crashing.
        name = str(uid)
    return name, uid


def parse_csv(command):
    """Run ``command`` and parse its comma-containing output lines as CSV.

    Returns a list of rows (each a list of strings), or ``None`` when no
    output line contains a comma.
    """
    csv_lines = [line for line in run_command(command) if ',' in line]
    if not csv_lines:
        return None
    return list(csv.reader(csv_lines))


def parse_smi(n_devices, i=None):
    """Query ``nvidia-smi`` and build GPU and compute-process objects.

    Args:
        n_devices: number of devices to probe one by one when the
            all-device query returns no CSV output.
        i: optional device index; when given, only that device is queried.

    Returns:
        ``(gpus, processes, warn)`` -- a list of ``GPU`` / ``DeadGPU``
        objects, a list of ``ComputeProcess`` objects, and a flag that is
        true when at least one GPU could not be read.
    """
    gpu_headers = ['bus_id', 'index', 'model', 'util', 'memory', 'used_memory', 'fan', 'temp', 'power', 'power_usage']
    smi_headers = ['pci.bus_id', 'index', 'name', 'utilization.gpu', 'memory.total', 'memory.used', 'fan.speed',
                   'temperature.gpu', 'power.limit', 'power.draw']
    p_headers = ['gpu_bus_id', 'pid', 'used_memory', 'name']
    # Must be an empty string, not None: formatting None into the command
    # would append the literal text "None" to the nvidia-smi invocation.
    device_arg = '' if i is None else f' -i {i}'

    gpus = parse_csv(f'nvidia-smi --query-gpu={",".join(smi_headers)} --format=csv,noheader,nounits{device_arg}')
    if gpus is None:
        if i is not None:
            # The single requested device produced no usable output.
            return [DeadGPU(i)], [], True
        # The all-device query failed: probe each device on its own so the
        # responsive ones can still be shown.
        all_gpus, all_processes, any_warn = [], [], False
        for device in range(n_devices):
            dev_gpus, dev_processes, dev_warn = parse_smi(n_devices, i=device)
            all_gpus.extend(dev_gpus)
            all_processes.extend(dev_processes)
            # `or`, so that a single failing device is enough to warn.
            any_warn = any_warn or dev_warn
        return all_gpus, all_processes, any_warn

    processes = parse_csv(
        f'nvidia-smi --query-compute-apps={",".join(p_headers)} --format=csv,noheader,nounits{device_arg}')

    warn = False
    out_gpus = []
    out_processes = []
    bus2i = {}  # PCI bus id -> position of that GPU in out_gpus
    for gi, g in enumerate(gpus):
        if g is None or len(g) < len(gpu_headers):
            # Malformed row: label it with the requested device index when
            # there is one, otherwise with its position in the output.
            out_gpus.append(DeadGPU(gi if i is None else i))
            warn = True
            continue
        gdict = {k: g[j].strip() for j, k in enumerate(gpu_headers)}
        gdict['free'] = True
        bus2i[gdict['bus_id']] = gi
        out_gpus.append(GPU(gdict['index'], gdict))

    for p in processes or []:
        if p is None or len(p) != len(p_headers):
            continue
        pdict = {k: v.strip() for k, v in zip(p_headers, p)}
        gpid = bus2i.get(pdict['gpu_bus_id'])
        if gpid is None:
            # Process reported on a bus id with no readable GPU row; skip
            # it rather than abort the whole refresh.
            continue
        proc = ComputeProcess(out_gpus[gpid], pdict)
        out_processes.append(proc)
        out_gpus[gpid].free = False
        out_gpus[gpid].owned = proc.owned or out_gpus[gpid].owned
    return out_gpus, out_processes, warn


def generate_table(dict_table, title=''):
    """Build a rich ``Table`` from the structure returned by ``list2dict``.

    Args:
        dict_table: dict with ``cols`` (column names), ``rows`` (lists of
            cell strings) and ``meta`` (one dict per row whose ``'STYLE'``
            entry lists a rich style per cell).
        title: table title; the header row is shown only when it is
            non-empty.

    Returns:
        The populated ``Table``.
    """
    # Imported here so the block stays self-contained.
    from rich.markup import escape

    table = Table(show_header=bool(title), header_style="bold", title=title)
    cols, rows, meta = dict_table['cols'], dict_table['rows'], dict_table['meta']
    for col in cols:
        table.add_column(col, justify="center")
    for row, row_meta in zip(rows, meta):
        # Escape cell text so that square brackets in it (for example in a
        # process name) are shown literally instead of parsed as markup.
        table.add_row(*[f'[{style}]{escape(cell)}' for cell, style in zip(row, row_meta['STYLE'])])
    return table


def run_str(command):
    """Run ``command`` and return its stdout as one string with newlines removed."""
    lines = run_command(command)
    return ''.join(lines)


def get_info(compact=False):
    """Build the title line: version, hostname and the current local time.

    The compact form drops the ``fancy-smi`` prefix, uses a two-digit year
    and narrower spacing.
    """
    hostname = run_str('hostname').strip()
    now = datetime.now()
    if not compact:
        stamp = now.strftime("%Y-%m-%d %H:%M:%S")
        return f'fancy-smi {VER} \t \t {hostname} \t \t {stamp}'
    stamp = now.strftime("%y-%m-%d %H:%M:%S")
    return f'{VER} \t {hostname} \t {stamp}'


def stylish_smi(n_devices):
    """Render the current GPU and process tables into one borderless grid.

    The grid stacks the GPU table, the process table and, when
    ``parse_smi`` flags a problem, a centred warning line.
    """
    gpu_objs, process_objs, warn = parse_smi(n_devices=n_devices)
    gpu_data = list2dict(gpu_objs)
    process_data = list2dict(process_objs)
    gpu_table = generate_table(gpu_data, get_info())
    process_table = generate_table(process_data, 'Processes')

    grid = Table(show_header=False, padding=(0, 0), show_edge=False, expand=True)
    grid.add_column(no_wrap=True)
    for renderable in (gpu_table, process_table):
        grid.add_row(Align(renderable, 'center'))
    if warn:
        grid.add_row(Align('Warning: One or more GPUs were thrown off the bus. Reboot to recover.', 'center'))
    return grid


def main():
    """Command-line entry point.

    With a refresh rate below 1 the tables are printed once. Otherwise the
    screen is cleared and the tables are redrawn every ``refresh_rate``
    seconds until interrupted with Ctrl-C.
    """
    parser = argparse.ArgumentParser(description='Stylish SMI')
    # The help text states the actual threshold used below (< 1).
    parser.add_argument('-r', '--refresh-rate', type=int, default=0,
                        help='refresh rate in seconds, runs once if <1 passed (default: 0)')
    args = parser.parse_args()
    console = Console()
    n_devices = n_gpus()
    if args.refresh_rate < 1:
        console.print(stylish_smi(n_devices=n_devices))
        return

    console.clear()
    try:
        with Live(stylish_smi(n_devices=n_devices), refresh_per_second=1 / args.refresh_rate, console=console) as live:
            while True:
                # Live already shows a fresh snapshot, so wait before the
                # next query instead of repeating it immediately.
                sleep(args.refresh_rate)
                live.update(stylish_smi(n_devices=n_devices))
    except KeyboardInterrupt:
        # Ctrl-C is the normal way to leave watch mode; exit without a
        # traceback.
        pass


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()
