MIT License

Copyright (c) 2025 Ahsan Shaokat

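Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.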


---

## 🧩 4. `losses.py` (for your repo)

The complete `losses.py` module, collected in one place:

```python
"""
Loss functions for training Mizan-based embedding models.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


def mizan_similarity_torch(
    v1: torch.Tensor,
    v2: torch.Tensor,
    p: float = 2.0,
    eps: float = 1e-8,
) -> torch.Tensor:
    """
    Compute the Mizan similarity between ``v1`` and ``v2`` along the last dim.

    The score is::

        1 - ||v1 - v2||_2 ** p / (||v1||_2 ** p + ||v2||_2 ** p + eps)

    All norms are Euclidean (L2). ``p`` is the exponent applied to those
    norms; it is NOT the order of the norm.

    For ``p=2`` the result lies (up to ``eps``) in ``[-1, 1]``:
    - identical vectors give 1;
    - orthogonal vectors of equal length give 0;
    - opposite vectors of equal length give -1.

    Args:
        v1: Floating-point tensor; the last dimension is reduced.
        v2: Floating-point tensor broadcastable against ``v1``.
        p: Exponent applied to the L2 norms.
        eps: Added to the denominator so two all-zero inputs do not divide
            by zero.

    Returns:
        Tensor of similarity scores with the last dimension reduced away.
    """
    # torch.norm is deprecated; torch.linalg.vector_norm with ord=2 computes
    # the same Euclidean norm over the last dimension.
    diff = torch.linalg.vector_norm(v1 - v2, ord=2, dim=-1)
    norm1 = torch.linalg.vector_norm(v1, ord=2, dim=-1)
    norm2 = torch.linalg.vector_norm(v2, ord=2, dim=-1)
    num = diff**p
    den = norm1**p + norm2**p + eps
    return 1.0 - num / den


class MizanContrastiveLoss(nn.Module):
    """
    Pairwise contrastive loss built on ``mizan_similarity_torch``.

    For every pair the per-pair loss is::

        y * (1 - sim) + (1 - y) * relu(sim - margin)

    Here ``y`` is the pair label converted to float. With ``y == 1`` the term
    pulls the similarity towards 1. With ``y == 0`` it penalises any
    similarity above ``margin``. The result is the mean over all pairs.

    Args:
        margin: Similarity level above which non-matching pairs are penalised.
        p: Exponent forwarded to ``mizan_similarity_torch``.
        eps: Denominator stabiliser forwarded to ``mizan_similarity_torch``.
    """

    def __init__(self, margin: float = 0.5, p: float = 2.0, eps: float = 1e-8):
        super().__init__()
        self.margin, self.p, self.eps = margin, p, eps

    def forward(
        self,
        emb1: torch.Tensor,
        emb2: torch.Tensor,
        labels: torch.Tensor,
    ) -> torch.Tensor:
        """Return the mean contrastive loss for the pairs ``(emb1, emb2)``."""
        score = mizan_similarity_torch(emb1, emb2, p=self.p, eps=self.eps)
        target = labels.float()

        # Matching pairs: distance of the similarity from a perfect 1.
        attract = target * (1.0 - score)
        # Non-matching pairs: only the part of the similarity above the margin.
        repel = (1.0 - target) * F.relu(score - self.margin)

        per_pair = attract + repel
        return per_pair.mean()


class MizanTripletLoss(nn.Module):
    """
    Triplet margin loss built on ``mizan_similarity_torch``.

    For each (anchor, positive, negative) triplet the loss is::

        relu(margin - (sim(anchor, positive) - sim(anchor, negative)))

    The result is averaged over all triplets. It is zero once the positive
    is at least ``margin`` more similar to the anchor than the negative.

    Args:
        margin: Required similarity gap between positive and negative.
        p: Exponent forwarded to ``mizan_similarity_torch``.
        eps: Denominator stabiliser forwarded to ``mizan_similarity_torch``.
            Its default matches the helper's, so omitting it keeps the
            previous behaviour.
    """

    def __init__(self, margin: float = 0.3, p: float = 2.0, eps: float = 1e-8):
        super().__init__()
        self.margin = margin
        self.p = p
        # Previously not configurable here, unlike MizanContrastiveLoss.
        self.eps = eps

    def forward(
        self,
        anchor: torch.Tensor,
        positive: torch.Tensor,
        negative: torch.Tensor,
    ) -> torch.Tensor:
        """Return the mean triplet loss over the given triplets."""
        sim_pos = mizan_similarity_torch(anchor, positive, p=self.p, eps=self.eps)
        sim_neg = mizan_similarity_torch(anchor, negative, p=self.p, eps=self.eps)
        loss = F.relu(self.margin - (sim_pos - sim_neg))
        return loss.mean()
