#!/usr/bin/env python

import os,sys
import unittest
from pixell.tests import *

# Run the test cases that the star import above pulled into this module's
# namespace. exit=False stops unittest from calling sys.exit() afterwards, so
# the benchmark code further down still runs.
unittest.main(exit=False)


# Program executed in a fresh interpreter by run_alm_benchmark(). It is run in
# a child process so that OMP_NUM_THREADS is already set when the numerical
# libraries are imported.
_ALM_BENCHMARK_SCRIPT = """\
from pixell import curvedsky, enmap, utils
import time
import numpy as np

np.random.seed(100)
shape, wcs = enmap.fullsky_geometry(res=12. * utils.arcmin)
imap = enmap.enmap(np.random.random(shape), wcs)

nsims = 40
lmax = int(6000 * (2. / 16.))

t0 = time.time()
for i in range(nsims):
    alm = curvedsky.map2alm(imap, lmax=lmax)
    omap = curvedsky.alm2map(alm, enmap.empty(shape, wcs))
t1 = time.time()

total = t1 - t0
print(f"{total:.4f} seconds.")
"""


def run_alm_benchmark(nthreads):
    """Time 40 map2alm/alm2map round trips in a child Python process.

    The child is started with ``OMP_NUM_THREADS`` set to *nthreads* and prints
    the elapsed wall-clock time as ``"<seconds> seconds."`` on its stdout,
    which is inherited from this process.

    Args:
        nthreads: Value for ``OMP_NUM_THREADS``; anything ``int()`` accepts,
            and it must be at least 1.

    Returns:
        The child's exit status (0 on success). Existing callers that ignore
        the return value are unaffected.

    Raises:
        ValueError: If *nthreads* is not an integer >= 1.
    """
    import subprocess

    thread_count = int(nthreads)
    if thread_count < 1:
        raise ValueError(f"nthreads must be >= 1, got {nthreads!r}")

    child_env = dict(os.environ, OMP_NUM_THREADS=str(thread_count))

    # Flush our own buffered output first so headings printed by the caller
    # appear before the child's timing line when stdout is a pipe or file.
    sys.stdout.flush()

    # Use the interpreter running this script (not whatever "python" is on
    # PATH) so the child sees the same pixell installation, and pass the
    # program as an argument list so no shell quoting is involved.
    completed = subprocess.run(
        [sys.executable or "python", "-c", _ALM_BENCHMARK_SCRIPT],
        env=child_env,
    )
    if completed.returncode != 0:
        print(
            f"alm benchmark exited with status {completed.returncode}",
            file=sys.stderr,
        )
    return completed.returncode

import multiprocessing


def _available_cpu_count():
    """Return how many CPUs this process may use (always at least 1)."""
    # Prefer the affinity mask where the platform exposes it: under taskset or
    # a batch scheduler's core binding it can be smaller than the machine-wide
    # count, and benchmarking with more threads than usable cores would
    # oversubscribe them.
    if hasattr(os, "sched_getaffinity"):
        return max(1, len(os.sched_getaffinity(0)))
    try:
        return max(1, multiprocessing.cpu_count())
    except NotImplementedError:
        # CPU count cannot be determined; still run the single-threaded test.
        return 1


max_threads = _available_cpu_count()

print("Single threaded alm test:")
run_alm_benchmark(1)

if max_threads > 1:
    print(f"Multi-threaded alm test with {max_threads} threads:")
    run_alm_benchmark(max_threads)
else:
    print("Multi-threading not detected.")
    

