#!python
from khmernormalizer import normalize
from khmercut import tokenize

def _processor(line):
    """Segment one line of text into tokens.

    Trailing whitespace (including the line terminator) is stripped from
    ``line`` before it is passed to ``tokenize``; whatever ``tokenize``
    returns is handed back unchanged.
    """
    return tokenize(line.rstrip())

def _processor_norm(line):
    """Normalize one line of text, then segment it into tokens.

    Trailing whitespace is stripped first. The stripped text is run through
    ``normalize`` with empty-string replacements for emoji and URLs and with
    ``remove_zwsp=True``, and the normalized text is then passed to
    ``tokenize``, whose result is returned unchanged.
    """
    stripped = line.rstrip()
    cleaned = normalize(
        stripped,
        emoji_replacement="",
        url_replacement="",
        remove_zwsp=True,
    )
    return tokenize(cleaned)

if __name__ == "__main__":
    from pathlib import Path
    from argparse import ArgumentParser, RawTextHelpFormatter
    from multiprocessing import Pool
    from tqdm.auto import tqdm

    # Command-line entry point: segment each input text file line by line and
    # write the separator-joined tokens to a same-named file in the output
    # directory.
    parser = ArgumentParser(
        "khmercut", description="A fast Khmer word segmentation toolkit.",
        formatter_class=RawTextHelpFormatter
    )
    parser.add_argument(
        "files", nargs="+", type=Path, default=[], help="Path to text files"
    )
    parser.add_argument(
        "-d",
        "--directory",
        type=Path,
        default=Path("khmercut_output"),
        help="Output folder",
    )

    parser.add_argument(
        "-s", "--separator", type=str, default=" ", help="Specify token separator"
    )

    parser.add_argument(
        "-j", "--jobs", type=int, default=1, help="Number of processors"
    )

    parser.add_argument(
        "-q", "--quiet", action="store_true", default=False, help="Disable progress output"
    )

    parser.add_argument(
        "-n",
        "--normalize",
        action="store_true",
        default=False,
        help="Normalize input text before processing",
    )

    args = parser.parse_args()

    # Validate explicitly instead of with `assert`, which is stripped when
    # Python runs with -O. parser.error prints usage and exits with status 2.
    if args.jobs <= 0:
        parser.error("jobs must be greater than 0")

    # args.directory is already a Path, so no `os` import is needed.
    args.directory.mkdir(parents=True, exist_ok=True)

    # Only normalize when --normalize was actually requested.
    processor = _processor_norm if args.normalize else _processor

    # The context manager terminates the workers if anything below raises;
    # on the normal path the pool is closed and joined gracefully first.
    with Pool(args.jobs) as pool:
        for file in args.files:
            output_file = args.directory / file.name

            # First pass: count lines so the progress bar has a total.
            # Files are opened as UTF-8 explicitly so Khmer text does not
            # depend on the platform's locale encoding.
            with open(file, encoding="utf-8") as inf:
                num_lines = sum(1 for _ in inf)

            # Second pass: stream lines through the worker pool. imap yields
            # results in input order, so output lines match input lines.
            with open(file, encoding="utf-8") as inf:
                with open(output_file, "w", encoding="utf-8") as outf:
                    with tqdm(
                        total=num_lines, desc=str(file), disable=args.quiet
                    ) as pbar:
                        for tokens in pool.imap(processor, inf):
                            outf.write(args.separator.join(tokens) + "\n")
                            pbar.update()

        pool.close()
        pool.join()
