#!/usr/bin/env python3

import fire
from stylegan2_pytorch import Trainer
from tqdm import tqdm

def train_from_folder(
    data,
    name = 'default',
    new = False,
    load_from = -1,
    image_size = 128,
    batch_size = 3,
    gradient_accumulate_every = 5,
    network_capacity = 16,
    num_train_steps = 100000
):
    """Train a StyleGAN2 model on the images in ``data``.

    Builds a ``Trainer`` for the run called ``name``. It then either wipes the
    run (``new=True``) or loads a checkpoint. Finally it trains for the number
    of steps still remaining in the ``num_train_steps`` budget.

    Args:
        data: image folder, forwarded as-is to ``Trainer``.
        name: run name, forwarded to ``Trainer`` (presumably used to locate
            this run's checkpoints/results -- confirm in stylegan2_pytorch).
        new: if True, call ``model.clear()``; otherwise call
            ``model.load(load_from)``.
        load_from: checkpoint selector forwarded to ``model.load``
            (-1 presumably means "latest" -- confirm).
        image_size: forwarded to ``Trainer``.
        batch_size: forwarded to ``Trainer``.
        gradient_accumulate_every: forwarded to ``Trainer``.
        network_capacity: forwarded to ``Trainer``.
        num_train_steps: total step budget. Only
            ``num_train_steps - model.steps`` iterations run here.
    """
    model = Trainer(
        name,
        data,
        batch_size = batch_size,
        gradient_accumulate_every = gradient_accumulate_every,
        image_size = image_size,
        network_capacity = network_capacity
    )

    if new:
        model.clear()
    else:
        model.load(load_from)

    # Run only what is left of the budget. range() is empty when model.steps
    # already meets or exceeds num_train_steps, so nothing runs in that case.
    remaining_steps = num_train_steps - model.steps
    for step in tqdm(range(remaining_steps), mininterval=10.):
        model.train()
        # Log every 50 iterations of *this* session. The counter starts at 0
        # here, not at model.steps, and the first iteration is logged too.
        if step % 50 == 0:
            model.print_log()

if __name__ == "__main__":
    # Let Fire build the command-line interface from train_from_folder's
    # signature, so each parameter can be supplied on the command line.
    fire.Fire(component=train_from_folder)
