"""Unit tests for adult neural network"""

import unittest

import numpy as np
import pandas as pd
import torch

from synthetic_aia_mia.predictor import adult
from synthetic_aia_mia.predictor import utk

class TestUtkNN(unittest.TestCase):
    """Tests for the neural-network predictor reached through ``adult``.

    NOTE(review): the class is named after UTK, but every test goes through
    ``adult`` (``AdultNN``, ``_train``) and uses the ``PINCP`` label column.
    Confirm which predictor module these tests are meant to target.
    """

    def test_predict_not_trained(self):
        """Check that predict raises AssertionError when fit was never called."""
        df = pd.DataFrame([])
        clf = adult.AdultNN()
        # The raised exception is not inspected, so the context is not bound.
        with self.assertRaises(AssertionError):
            clf.predict(df)

    def test_neural_network(self):
        """Check that the PyTorch model trains and yields one output per row."""
        samples = np.array([[1, 0], [2, 0], [3, 1]])
        df = pd.DataFrame(samples, columns=["x", "PINCP"])
        config = {"l1": 2, "l2": 2, "lr": 0.001, "batch_size": 1}
        model = adult._train(config, df, stand_alone=True)

        # Only the single feature column is fed to the trained model.
        features = torch.tensor([[1], [2], [3]], dtype=torch.float)
        output = model(features)
        self.assertEqual(len(output), len(features))

    def test_predict_trained(self):
        """Check that a fitted AdultNN predicts a PINCP entry for every row."""
        n_rows = 1000
        # Fixed seed keeps the generated data set reproducible between runs.
        rng = np.random.default_rng(0)
        features = rng.uniform(0, 1, [n_rows, 1])
        labels = rng.integers(0, 2, [n_rows, 1])  # binary labels in {0, 1}
        df = pd.DataFrame(np.hstack([features, labels]), columns=["x", "PINCP"])

        clf = adult.AdultNN()
        clf.fit(df)
        pred = clf.predict(df.drop("PINCP", axis=1))

        self.assertIn("PINCP", pred)
        self.assertEqual(len(pred["PINCP"]), n_rows)
