#!/usr/bin/env python


from matplotlib import pyplot as plt


import argparse
from pNbody import *
from pNbody import units
import string 

from scipy import optimize

import numpy as np
from pNbody import pychem

# Texts shown by ``--help``. The description is left empty; the epilog lists
# example invocations.
description = ""
epilog = """
Examples:
--------
imf_sample_IMF
imf_sample_IMF --M0 1e5
"""

# RawDescriptionHelpFormatter prints description/epilog verbatim, so the
# line breaks of the examples above survive in the help output.
parser = argparse.ArgumentParser(
    description=description,
    epilog=epilog,
    formatter_class=argparse.RawDescriptionHelpFormatter,
)

# --M0: total mass of the IMF in solar masses. The value is stored as a float
# under ``opt.M0`` ("store" action and the "M0" destination are what argparse
# uses by default for this flag, so they are not spelled out).
parser.add_argument(
    "--M0",
    metavar="FLOAT",
    type=float,
    default=1e5,
    help="IMF total mass in Msol",
)

########################################################################
#             M A I N
########################################################################

opt = parser.parse_args()

# Total mass to distribute over the sampled stars, in solar masses (--M0).
M0 = opt.M0

# IMF definition handed to pychem: the mass limits plus the "as" and "ms"
# lists (presumably the power-law slopes of the IMF segments and the masses
# at which the slope changes -- confirm against the pychem documentation).
params = {
    "Mmax": 50.,
    "Mmin": 0.05,
    "as": [0.7, -0.8, -1.7, -1.3],
    "ms": [0.08, 0.5, 1.0],
}

pychem.set_parameters(params)
mmax = pychem.get_Mmax()
mmin = pychem.get_Mmin()

# Number of stars with a mass between mmin and mmax, scaled by M0, i.e. the
# number of stars contained in a particle of mass M0.
# get_imf_N is called with one-element arrays; the single value is extracted
# explicitly so that the int() conversion and the print below operate on a
# plain scalar (implicit conversion of an ndim > 0 array to a scalar is
# deprecated in recent NumPy releases).
N = float(np.ravel(pychem.get_imf_N(np.array([mmin]), np.array([mmax])) * M0)[0])

# Draw the individual stellar masses. N is truncated to an integer star count.
# NOTE(review): the meaning of the second argument (1) is not visible here --
# check pychem.imf_sampling.
ms = pychem.imf_sampling(int(N), 1)


# in Msol
print("number of stars", N)
# np.sum instead of the builtin sum: no Python-level loop over the samples.
print("total mass     ", np.sum(ms))



##############################################
# plot the imf
##############################################


# Number of equal-width mass bins between mmin and mmax.
n_bins = 10000

# Left edges of the bins. linspace yields exactly n_bins + 1 edges ending at
# mmax, whereas arange with a floating-point step may or may not append an
# extra edge depending on rounding.
bins = np.linspace(mmin, mmax, n_bins + 1)

# histogram
# For every edge, searchsorted returns how many sampled masses lie strictly
# below it. Appending len(ms) closes the last bin, so the differences of
# consecutive entries are the star counts per bin (one count per edge).
cumulative = np.searchsorted(np.sort(ms), bins)
cumulative = np.append(cumulative, len(ms))
hx = np.diff(cumulative)
# Weight each count by the left-edge mass of its bin: approximate stellar
# mass contained in the bin.
hx = hx * bins

print("total mass (histogram)", np.sum(hx))


plt.plot(bins, hx)

# Both axes logarithmic.
plt.loglog()
plt.xlabel(r'$\rm{Mass}\,[\rm{M}_\odot]$')
plt.ylabel(r'$\rm{number\,of\,stars}\,,\,\rm{mass\,fraction}$')

plt.show()






