.. _build:

*DIRESA* model
==============

**3. Build DIRESA model**

The *build_diresa* function builds a *DIRESA* model from convolutional and/or dense layers. 
The *diresa_model* function builds a *DIRESA* model from a custom encoder and decoder (see below). 
Here we build a model with an input shape of *(3,)* for the 3D butterfly points. 
The encoder has 3 dense layers with 40, 20 and 2 units (the last value is the dimension of the latent space). 
The decoder mirrors the encoder. The *DIRESA* model has 3 loss functions: 
the reconstruction loss (usually the MSE), the covariance loss and a distance loss 
(here the MSE distance loss). The weights of the different loss functions are specified when the model is compiled.

.. code-block:: ipython
  
  from diresa.models import build_diresa
  from diresa.loss import mse_dist_loss, LatentCovLoss

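  # Encoder with dense layers of 40, 20 and 2 units; the decoder mirrors the encoder
  diresa = build_diresa(input_shape=(3,), dense_units=(40, 20, 2))

  # Losses and weights in order: reconstruction (1.), covariance (3.), distance (1.5)
  diresa.compile(loss=['MSE', LatentCovLoss(1.), mse_dist_loss], loss_weights=[1., 3., 1.5])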

To reduce the effort of tuning the loss weights, we can use annealing for the covariance loss. In this case, 
the covariance weight starts from an initial value (here the Keras backend variable *cov_weight* is initialized to 0.) 
and is increased until the covariance loss reaches a certain target. The model is then compiled with 
*LatentCovLoss(cov_weight)* instead of a fixed covariance weight.

.. code-block:: ipython
  
  import keras.backend as K
  from diresa.callback import LossWeightAnnealing

  # Covariance loss weight, starts at 0. and is passed to the annealing callback during training
  cov_weight = K.variable(0.)
  diresa.compile(loss=['MSE', LatentCovLoss(cov_weight), mse_dist_loss], loss_weights=[1., 1., 1.5])
  diresa.summary(expand_nested=True)
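
The annealing rule described above can be sketched in plain Python. This is only an illustration, not the 
*LossWeightAnnealing* implementation: the function *anneal_weight*, the additive update, the 0-based epoch count 
and the loss values are all assumptions made for this example.

.. code-block:: python

  def anneal_weight(weight, loss, target_loss, anneal_step, epoch, start_epoch):
      """Return the weight to use after this epoch (hypothetical helper).

      Assumption: the weight grows additively by anneal_step as long as the
      monitored loss is still above the target and annealing has started.
      """
      if epoch >= start_epoch and loss > target_loss:
          return weight + anneal_step
      return weight

  # Made-up covariance loss values, one per epoch (not real training output)
  losses = [0.01, 0.008, 0.006, 0.004, 0.002, 0.00005, 0.00005]

  weight = 0.0
  history = []
  for epoch, loss in enumerate(losses):
      weight = anneal_weight(weight, loss, target_loss=0.0001,
                             anneal_step=0.2, epoch=epoch, start_epoch=3)
      history.append(weight)

In this sketch the weight stays at its initial value until *start_epoch*, grows while the loss is above the target, 
and stops changing once the target is reached.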
  
**4. Train the DIRESA model**

We train the *DIRESA* model in the standard way. The output of the decoder should fit the input of the encoder. 
The batch size should be large enough for the covariance loss, which computes 
the covariance matrix of the latent space components over the batch.
In the *LossWeightAnnealing* callback, we specify the target (*target_loss*) for the mean squared covariance 
between the latent components, the step size by which the annealing weight factor is increased (*anneal_step*) 
and the epoch from which annealing starts (*start_epoch*). If annealing is not used, 
the *fit* method is called without the callback.

.. code-block:: ipython
  
  callback = [LossWeightAnnealing(cov_weight, loss_name="val_Latent_loss", target_loss=0.0001, anneal_step=0.2, start_epoch=3)]
  diresa.fit((train, train_twin), train, 
             validation_data=((val, val_twin), val),
             epochs=20, batch_size=512, shuffle=True, verbose=2, callbacks=callback)
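
To illustrate why the batch size matters for the covariance loss, the following NumPy sketch computes the mean squared 
covariance between latent components over a batch. It is not the *LatentCovLoss* implementation: the helper names 
*mean_squared_covariance* and *average_estimate* and the random stand-in data are assumptions made for this example.

.. code-block:: python

  import numpy as np

  def mean_squared_covariance(latent):
      """Mean of the squared off-diagonal entries of the batch covariance matrix."""
      cov = np.cov(latent, rowvar=False)            # columns are latent components
      off_diagonal = ~np.eye(cov.shape[0], dtype=bool)
      return float(np.mean(cov[off_diagonal] ** 2))

  rng = np.random.default_rng(0)

  def average_estimate(batch_size, n_batches=200, latent_dim=2):
      """Average the estimate over many batches of independent components."""
      values = [mean_squared_covariance(rng.normal(size=(batch_size, latent_dim)))
                for _ in range(n_batches)]
      return float(np.mean(values))

  # The components are independent, so the true covariance is zero;
  # any non-zero value is estimation noise caused by the finite batch.
  small_batch = average_estimate(8)
  large_batch = average_estimate(512)

  # Strongly dependent components give a much larger value
  x = rng.normal(size=512)
  dependent = mean_squared_covariance(np.column_stack([x, 2.0 * x]))

With small batches the estimate is dominated by sampling noise, so a low target such as the *target_loss* used above 
is only meaningful when the batch is large enough.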
  
**5. Encoder and decoder submodel**

We extract the encoder and decoder submodels with the *cut_sub_model* function, 
so that we can make predictions in the latent space and in the decoded space.

.. code-block:: ipython
  
  from diresa.toolbox import cut_sub_model
  compress_model = cut_sub_model(diresa, 'Encoder')
  decode_model = cut_sub_model(diresa, 'Decoder')
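  # Encode the training data into the latent space
  latent = compress_model.predict(train)
  # Decode the latent representation back to the original space
  predict = decode_model.predict(latent)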
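
A simple way to check the decoded output is to compare it with the original points. The sketch below uses a 
hypothetical helper, *reconstruction_mse*, and small stand-in arrays; in the tutorial the inputs would be 
*train* and *predict*.

.. code-block:: python

  import numpy as np

  def reconstruction_mse(original, decoded):
      """Mean squared error between the original points and their reconstruction."""
      original = np.asarray(original, dtype=float)
      decoded = np.asarray(decoded, dtype=float)
      return float(np.mean((original - decoded) ** 2))

  # Stand-in data (assumption): two 3D points and a reconstruction offset by 0.1
  points = np.array([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])
  decoded_points = points + 0.1
  error = reconstruction_mse(points, decoded_points)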