| | """
|
| | Unit tests for functional API (bitlinear_python, greedy_ternary_decomposition, etc.)
|
| |
|
| | These tests are here to validate the correctness of the pure PyTorch reference implementations. Here are the following test cases:
|
| |
|
| | TestBitLinearPython (5 tests)
|
| | 1. test_shape_correctness - Verifies output dimensions for 3D inputs
|
| | 2. test_no_bias - Tests forward pass without bias term
|
| | 3. test_ternary_constraint - Validates ternary weight values {-1, 0, +1}
|
| | 4. test_gamma_scaling - Verifies gamma scaling is applied correctly
|
| | 5. test_numerical_correctness - Compares against manual torch computation
|
| |
|
| | TestGreedyTernaryDecomposition (4 tests)
|
| | 1. test_decomposition_shape - Checks output tensor shapes
|
| | 2. test_ternary_values - Ensures all decomposed weights are ternary
|
| | 3. test_reconstruction_error - Validates error decreases with more components
|
| | 4. test_single_component - Tests k=1 edge case
|
| |
|
| | TestMultiTernaryLinearPython (2 tests)
|
| | 1. test_shape_correctness - Verifies output shape
|
| | 2. test_equivalence_to_sum - Confirms equivalence to summing individual operations
|
| |
|
| | TestActivationQuant (2 tests)
|
| | 1. test_quantization_range - Validates quantization behavior and output
|
| | 2. test_absmax_scaling - Tests per-token absmax scaling
|
| |
|
| | TestFunctionalIntegration (3 tests)
|
| | 1. test_full_pipeline - End-to-end: decomposition → multi-ternary forward
|
| | 2. test_bitlinear_with_activation_quant - Combines activation quantization with bitlinear
|
| | 3. test_multi_ternary_end_to_end - Tests different k values with reconstruction validation
|
| | """
|
| |
|
| | import pytest
|
| | import torch
|
| | import torch.nn as nn
|
| |
|
| | from bitlinear.functional import (
|
| | bitlinear_python,
|
| | greedy_ternary_decomposition,
|
| | multi_ternary_linear_python,
|
| | activation_quant,
|
| | )
|
| |
|
| |
|
class TestBitLinearPython:
    """Tests for bitlinear_python function."""

    def test_shape_correctness(self):
        """Output for a 3D input must be (batch, seq_len, out_features)."""
        batch_size, seq_len, in_features, out_features = 32, 128, 512, 1024
        activations = torch.randn(batch_size, seq_len, in_features)
        weights = torch.randint(-1, 2, (out_features, in_features)).float()
        scales = torch.ones(out_features)
        offsets = torch.zeros(out_features)

        result = bitlinear_python(activations, weights, scales, offsets)

        assert result.shape == (batch_size, seq_len, out_features)

    def test_no_bias(self):
        """Forward pass must succeed when no bias term is supplied."""
        batch_size, in_features, out_features = 16, 256, 512
        activations = torch.randn(batch_size, in_features)
        weights = torch.randint(-1, 2, (out_features, in_features)).float()
        scales = torch.ones(out_features)

        result = bitlinear_python(activations, weights, scales, bias=None)

        assert result.shape == (batch_size, out_features)
        assert not torch.isnan(result).any()

    def test_ternary_constraint(self):
        """Function must operate on weights drawn from {-1, 0, +1}."""
        activations = torch.randn(8, 64)
        weights = torch.randint(-1, 2, (128, 64)).float()
        scales = torch.ones(128)

        # Confirm the fixture itself is ternary before exercising the op.
        for value in torch.unique(weights).tolist():
            assert value in [-1.0, 0.0, 1.0]

        result = bitlinear_python(activations, weights, scales)
        assert result.shape == (8, 128)
        assert not torch.isnan(result).any()

    def test_gamma_scaling(self):
        """Scaling by gamma must equal post-hoc multiplication by gamma."""
        activations = torch.randn(4, 32)
        weights = torch.randint(-1, 2, (64, 32)).float()
        scales = torch.rand(64) * 2 + 0.5

        scaled = bitlinear_python(activations, weights, scales, bias=None)

        # Reference: run with unit gamma, then apply the scale manually.
        unit_scales = torch.ones_like(scales)
        unscaled = bitlinear_python(activations, weights, unit_scales, bias=None)
        reference = unscaled * scales.unsqueeze(0)

        assert torch.allclose(scaled, reference, atol=1e-5)

    def test_numerical_correctness(self):
        """Result must match an explicit matmul + scale + bias computation."""
        in_features, out_features = 128, 256
        activations = torch.randn(16, in_features)
        weights = torch.randint(-1, 2, (out_features, in_features)).float()
        scales = torch.ones(out_features)
        offsets = torch.randn(out_features)

        result = bitlinear_python(activations, weights, scales, offsets)

        reference = activations @ weights.t() * scales.unsqueeze(0) + offsets

        assert torch.allclose(result, reference, atol=1e-6)
|
| |
|
| |
|
class TestGreedyTernaryDecomposition:
    """Tests for greedy_ternary_decomposition function."""

    def test_decomposition_shape(self):
        """Decomposing an (out, in) matrix yields (k, out, in) ternary
        weights and (k, out) per-component scales."""
        W = torch.randn(512, 768)
        k = 4
        W_ternary, gammas = greedy_ternary_decomposition(W, k)

        assert W_ternary.shape == (k, 512, 768)
        assert gammas.shape == (k, 512)

    def test_ternary_values(self):
        """Every decomposed weight entry must be in {-1, 0, +1}."""
        W = torch.randn(64, 128)
        k = 2
        W_ternary, gammas = greedy_ternary_decomposition(W, k)

        unique_values = torch.unique(W_ternary)
        assert all(v in [-1.0, 0.0, 1.0] for v in unique_values.tolist()), \
            f"Found non-ternary values: {unique_values.tolist()}"

    def test_reconstruction_error(self):
        """Reconstruction error must strictly decrease as k grows."""
        # Seed the RNG: the strictly-decreasing check below would otherwise
        # be flaky, since a greedy decomposition only guarantees the error
        # is non-increasing for an arbitrary random matrix.
        torch.manual_seed(0)
        W = torch.randn(128, 256)
        errors = []

        for k in [1, 2, 4, 8]:
            W_ternary, gammas = greedy_ternary_decomposition(W, k)

            # Rebuild the dense matrix from its k scaled ternary components.
            reconstruction = torch.zeros_like(W)
            for i in range(k):
                reconstruction += gammas[i].unsqueeze(1) * W_ternary[i]

            errors.append(torch.norm(W - reconstruction).item())

        # Compare each consecutive pair instead of repeating the assert.
        for prev_err, next_err in zip(errors, errors[1:]):
            assert prev_err > next_err, f"Error not decreasing: {errors}"

    def test_single_component(self):
        """k=1 edge case: plain ternary quantization with one scale row."""
        W = torch.randn(32, 64)
        k = 1
        W_ternary, gammas = greedy_ternary_decomposition(W, k)

        assert W_ternary.shape == (1, 32, 64)
        assert gammas.shape == (1, 32)

        unique_values = torch.unique(W_ternary)
        assert all(v in [-1.0, 0.0, 1.0] for v in unique_values.tolist())
|
| |
|
| |
|
class TestMultiTernaryLinearPython:
    """Tests for multi_ternary_linear_python function."""

    def test_shape_correctness(self):
        """Multi-ternary forward must produce a (batch, out_features) tensor."""
        batch_size, in_features, out_features = 16, 128, 256
        num_components = 4

        inputs = torch.randn(batch_size, in_features)
        components = torch.randint(
            -1, 2, (num_components, out_features, in_features)
        ).float()
        scales = torch.rand(num_components, out_features)
        offsets = torch.randn(out_features)

        result = multi_ternary_linear_python(inputs, components, scales, offsets)

        assert result.shape == (batch_size, out_features)

    def test_equivalence_to_sum(self):
        """Multi-ternary output must equal the sum of per-component
        bitlinear passes plus the shared bias added once."""
        batch_size, in_features, out_features = 8, 64, 128
        num_components = 3

        inputs = torch.randn(batch_size, in_features)
        components = torch.randint(
            -1, 2, (num_components, out_features, in_features)
        ).float()
        scales = torch.rand(num_components, out_features)
        offsets = torch.randn(out_features)

        fused = multi_ternary_linear_python(inputs, components, scales, offsets)

        # Reference: accumulate each component independently, add bias once.
        accumulated = torch.zeros(batch_size, out_features)
        for idx in range(num_components):
            accumulated = accumulated + bitlinear_python(
                inputs, components[idx], scales[idx], bias=None
            )
        accumulated = accumulated + offsets

        assert torch.allclose(fused, accumulated, atol=1e-5)
|
| |
|
| |
|
class TestActivationQuant:
    """Tests for activation quantization."""

    def test_quantization_range(self):
        """Quantization must preserve shape, alter values, and stay finite."""
        activations = torch.randn(16, 128, 512) * 10

        quantized = activation_quant(activations, bits=8)

        assert quantized.shape == activations.shape
        # Quantization is lossy, so the output must differ from the input.
        assert not torch.allclose(activations, quantized, atol=1e-6)
        assert torch.isfinite(quantized).all()

    def test_absmax_scaling(self):
        """Per-token absmax scaling must keep quantization error small."""
        activations = torch.tensor([
            [1.0, 2.0, 3.0, 4.0],
            [-5.0, -10.0, 5.0, 10.0],
        ])

        quantized = activation_quant(activations, bits=8)

        assert quantized.shape == (2, 4)
        assert torch.isfinite(quantized).all()

        # At 8 bits the mean relative error should be well under 10%.
        relative_error = (activations - quantized).abs() / (activations.abs() + 1e-5)
        assert relative_error.mean() < 0.1
|
| |
|
| |
|
| |
|
class TestFunctionalIntegration:
    """Integration tests combining multiple functional components."""

    def test_full_pipeline(self):
        """End-to-end: decompose a dense weight, run the multi-ternary
        forward pass, and compare against the dense reference."""
        # Seed the RNG so the relative-error bound below is deterministic
        # instead of depending on the random draw of W_dense and x.
        torch.manual_seed(0)
        in_features, out_features = 256, 512
        W_dense = torch.randn(out_features, in_features)

        k = 4
        W_ternary, gammas = greedy_ternary_decomposition(W_dense, k)

        batch_size = 16
        x = torch.randn(batch_size, in_features)
        bias = torch.randn(out_features)

        output = multi_ternary_linear_python(x, W_ternary, gammas, bias)

        assert output.shape == (batch_size, out_features)
        assert torch.isfinite(output).all()

        # The k-component approximation should land near the dense result.
        # NOTE(review): < 1.0 is a deliberately loose sanity bound; tighten
        # once the expected accuracy of k=4 decomposition is characterized.
        output_dense = torch.matmul(x, W_dense.t()) + bias
        relative_error = torch.norm(output - output_dense) / torch.norm(output_dense)
        assert relative_error < 1.0

    def test_bitlinear_with_activation_quant(self):
        """Activation quantization composed with bitlinear must produce a
        finite, correctly-shaped output."""
        batch_size, in_features, out_features = 8, 128, 256

        x = torch.randn(batch_size, in_features)
        W_ternary = torch.randint(-1, 2, (out_features, in_features)).float()
        gamma = torch.ones(out_features)

        # Quantize activations first, then run the ternary matmul.
        x_quant = activation_quant(x, bits=8)
        output = bitlinear_python(x_quant, W_ternary, gamma)

        assert output.shape == (batch_size, out_features)
        assert torch.isfinite(output).all()

    def test_multi_ternary_end_to_end(self):
        """Decomposition followed by the fused op must equal a dense matmul
        against the reconstructed weight, for several k values."""
        # Seed for reproducibility of the allclose comparison below.
        torch.manual_seed(0)
        W = torch.randn(64, 128) * 0.1
        x = torch.randn(4, 128)

        for k in [1, 2, 4]:
            W_ternary, gammas = greedy_ternary_decomposition(W, k)
            output = multi_ternary_linear_python(x, W_ternary, gammas, bias=None)

            assert output.shape == (4, 64)
            assert torch.isfinite(output).all()

            # Reconstruct the effective dense weight from the components.
            W_reconstructed = torch.zeros_like(W)
            for i in range(k):
                W_reconstructed += gammas[i].unsqueeze(1) * W_ternary[i]

            output_expected = torch.matmul(x, W_reconstructed.t())
            assert torch.allclose(output, output_expected, atol=1e-4)
|
| |
|