import pytest
import torch


def test_multiheaded_hebbian_attention_layer():
    """Smoke-test ``MultiHeadedHebbianAttentionLayer``.

    Verifies that:
    * the constructor raises ``ValueError`` when ``embed_dim`` is not a
      multiple of ``n_heads``;
    * a forward pass on a ``(batch, n_neurons)`` input returns a tensor of
      shape ``(batch, n_neurons, embed_dim)``;
    * ``detach_()`` and ``zero_()`` can be called without raising (their
      effect is not inspected here).
    """
    from sparks.models.blocks import MultiHeadedHebbianAttentionLayer

    n_neurons, embed_dim, n_heads, batch_size = 100, 64, 4, 32

    layer = MultiHeadedHebbianAttentionLayer(
        n_neurons=n_neurons,
        embed_dim=embed_dim,
        n_heads=n_heads,
        tau_s=1.0,
        dt=0.001,
        data_type='ephys',
    )

    # 63 cannot be split evenly across 4 heads, so construction must fail.
    bad_embed_dim = 63
    with pytest.raises(ValueError):
        MultiHeadedHebbianAttentionLayer(
            n_neurons=n_neurons,
            embed_dim=bad_embed_dim,
            n_heads=4,
        )

    # Gaussian noise stands in for spike data; only the output shape is checked.
    spikes = torch.randn(batch_size, n_neurons)
    expected_shape = (batch_size, n_neurons, embed_dim)
    assert layer(spikes).shape == expected_shape

    # State-management hooks: only checked to run without error.
    layer.detach_()
    layer.zero_()


def test_hebbian_attention_block():
    """Smoke-test ``HebbianAttentionBlock``.

    Builds a 2-head block for ``'ephys'`` data. It then checks that a forward
    pass on a ``(batch, n_neurons)`` input yields a
    ``(batch, n_neurons, embed_dim)`` tensor. Finally it checks that
    ``detach_()`` and ``zero_()`` run without raising.
    """
    from sparks.models.blocks import HebbianAttentionBlock

    n_neurons, embed_dim, n_heads, batch_size = 100, 64, 2, 16

    block = HebbianAttentionBlock(
        n_neurons=n_neurons,
        embed_dim=embed_dim,
        n_heads=n_heads,
        data_type='ephys',
    )

    # Random input is enough here: the assertion is purely about shape.
    spikes = torch.randn(batch_size, n_neurons)
    expected_shape = (batch_size, n_neurons, embed_dim)
    assert block(spikes).shape == expected_shape

    # Only verified not to raise; their effect on internal state is not checked.
    block.detach_()
    block.zero_()


def test_attention_block():
    """Smoke-test the conventional (non-Hebbian) ``AttentionBlock``.

    A ``(batch, seq_len, embed_dim)`` input must come back with the same
    shape. The ``zero_()`` and ``detach_()`` methods, presumably provided
    for API compatibility with the Hebbian blocks, must be callable without
    raising.
    """
    from sparks.models.blocks import AttentionBlock

    embed_dim, n_heads, batch_size, seq_len = 64, 4, 16, 100

    block = AttentionBlock(embed_dim=embed_dim, n_heads=n_heads)

    # Input and output shapes are expected to match exactly.
    input_shape = (batch_size, seq_len, embed_dim)
    x = torch.randn(*input_shape)
    assert block(x).shape == input_shape

    # Compatibility hooks: only checked to run without error.
    block.zero_()
    block.detach_()