| | """
|
| | Memory usage benchmarking for BitLinear.
|
| |
|
| | This script measures actual memory usage and compression ratios for BitLinear
|
| | compared to standard nn.Linear layers.
|
| | """
|
| |
|
| | import torch
|
| | import torch.nn as nn
|
| | from bitlinear import BitLinear, MultiTernaryLinear, pack_ternary_base3, estimate_memory_savings
|
| | import sys
|
| |
|
| |
|
def get_tensor_memory_mb(tensor):
    """Return the memory footprint of *tensor* in megabytes.

    Computed as bytes-per-element times element count, converted to MiB.
    """
    total_bytes = tensor.nelement() * tensor.element_size()
    return total_bytes / (1024 ** 2)
|
| |
|
| |
|
def get_model_memory_mb(model):
    """Return the combined memory footprint of *model*'s parameters in MB.

    Sums bytes over every tensor yielded by ``model.parameters()`` and
    converts the total to MiB.
    """
    per_param_bytes = (p.element_size() * p.nelement() for p in model.parameters())
    return sum(per_param_bytes) / (1024 ** 2)
|
| |
|
| |
|
def analyze_layer_memory(in_features, out_features):
    """Print and collect memory statistics for one layer configuration.

    Builds an ``nn.Linear`` of the given shape plus its BitLinear and
    MultiTernaryLinear (k=2) counterparts, reports measured footprints,
    a theoretical base-3-packed footprint, and the real compression
    achieved by ``pack_ternary_base3``. Returns the numbers as a dict.
    """
    rule = "=" * 100
    print(f"\n{rule}")
    print(f"Layer: {in_features} → {out_features}")
    print(f"{rule}\n")

    # Reference layer and its two quantized counterparts.
    dense = nn.Linear(in_features, out_features, bias=True)
    bitlinear = BitLinear.from_linear(dense)
    multi = MultiTernaryLinear.from_linear(dense, k=2)

    mem_linear = get_model_memory_mb(dense)
    mem_bitlinear = get_model_memory_mb(bitlinear)
    mem_multi = get_model_memory_mb(multi)

    # Theoretical packed size: 5 ternary weights fit in one byte (3^5 = 243),
    # plus fp32 bias and fp32 per-output gamma scales.
    weight_count = in_features * out_features
    packed_weight_bytes = (weight_count + 4) // 5
    overhead_bytes = out_features * 4 + out_features * 4  # bias + gamma
    theoretical_packed_mb = (packed_weight_bytes + overhead_bytes) / (1024 ** 2)

    compression_current = mem_linear / mem_bitlinear
    compression_packed = mem_linear / theoretical_packed_mb

    print(f"nn.Linear memory: {mem_linear:10.4f} MB")
    print(f"BitLinear memory (current): {mem_bitlinear:10.4f} MB (ratio: {compression_current:5.2f}x)")
    print(f"BitLinear memory (packed): {theoretical_packed_mb:10.4f} MB (ratio: {compression_packed:5.2f}x)")
    print(f"MultiTernaryLinear memory (k=2): {mem_multi:10.4f} MB (ratio: {mem_linear/mem_multi:5.2f}x)")

    # Verify the packer's real compression against the unpacked ternary tensor.
    print(f"\nPacking Test:")
    print("-" * 100)

    W_ternary = bitlinear.W_ternary
    packed, _original_shape = pack_ternary_base3(W_ternary)

    unpacked_size_mb = get_tensor_memory_mb(W_ternary)
    packed_size_mb = get_tensor_memory_mb(packed)
    actual_compression = unpacked_size_mb / packed_size_mb

    print(f"Unpacked weights: {unpacked_size_mb:10.4f} MB")
    print(f"Packed weights: {packed_size_mb:10.4f} MB")
    print(f"Actual compression: {actual_compression:8.2f}x")

    return {
        'in_features': in_features,
        'out_features': out_features,
        'mem_linear': mem_linear,
        'mem_bitlinear': mem_bitlinear,
        'mem_packed': theoretical_packed_mb,
        'mem_multi': mem_multi,
        'compression_current': compression_current,
        'compression_packed': compression_packed,
    }
|
| |
|
| |
|
def run_memory_benchmarks():
    """Run the full memory benchmark suite.

    Analyzes a set of representative layer shapes, prints a markdown
    summary table with aggregate compression statistics, then projects
    the savings onto a GPT-2-style 12-layer transformer.
    """
    banner = "=" * 100
    print(banner)
    print("BitLinear Memory Benchmarks")
    print(banner)
    print(f"\nPyTorch version: {torch.__version__}")

    # Square attention-sized layers plus two typical MLP up-projections.
    layer_sizes = [
        (512, 512),
        (768, 768),
        (1024, 1024),
        (2048, 2048),
        (4096, 4096),
        (768, 3072),
        (1024, 4096),
    ]

    rows = [analyze_layer_memory(fan_in, fan_out) for fan_in, fan_out in layer_sizes]

    print(f"\n\n{banner}")
    print("Memory Compression Summary (Markdown Format)")
    print(f"{banner}\n")

    print("| Layer Size | nn.Linear (MB) | BitLinear Current (MB) | BitLinear Packed (MB) | Compression (Packed) |")
    print("|------------|----------------|------------------------|----------------------|----------------------|")

    for row in rows:
        print(f"| {row['in_features']}×{row['out_features']:<4} | {row['mem_linear']:14.4f} | "
              f"{row['mem_bitlinear']:22.4f} | {row['mem_packed']:20.4f} | {row['compression_packed']:20.2f}x |")

    print(f"\n{banner}")
    print("Summary Statistics")
    print(f"{banner}\n")

    ratios = [row['compression_packed'] for row in rows]
    print(f"Average compression ratio: {sum(ratios) / len(ratios):.2f}x")
    print(f"Minimum compression ratio: {min(ratios):.2f}x")
    print(f"Maximum compression ratio: {max(ratios):.2f}x")

    print(f"\n{banner}")
    print("Real-World Example: GPT-2 Style Transformer")
    print(f"{banner}\n")

    # GPT-2 small: 12 blocks, hidden width 768, 4x MLP expansion.
    num_layers = 12
    d_model = 768
    d_ff = 3072

    # Q/K/V/O projections plus the two MLP matrices, per block.
    linear_per_layer = (4 * d_model * d_model) + (d_model * d_ff) + (d_ff * d_model)
    linear_total = linear_per_layer * num_layers

    linear_mem_mb = (linear_total * 4) / (1024 ** 2)
    # Base-3 packing: 5 ternary weights per byte.
    packed_mem_mb = ((linear_total + 4) // 5) / (1024 ** 2)

    # fp32 overhead kept alongside packed weights: biases plus gamma scales,
    # one value per output feature of each projection.
    params_per_layer = (4 * d_model) + d_ff + d_model
    gammas_per_layer = (4 * d_model) + d_ff + d_model
    overhead_mb = ((params_per_layer + gammas_per_layer) * num_layers * 4) / (1024 ** 2)

    packed_total_mb = packed_mem_mb + overhead_mb
    compression = linear_mem_mb / packed_total_mb

    print(f"Configuration: {num_layers} layers, d_model={d_model}, d_ff={d_ff}")
    print(f"Total linear parameters: {linear_total:,}")
    print(f"\nnn.Linear memory: {linear_mem_mb:10.2f} MB")
    print(f"BitLinear packed: {packed_total_mb:10.2f} MB")
    print(f"Memory saved: {linear_mem_mb - packed_total_mb:10.2f} MB")
    print(f"Compression ratio: {compression:10.2f}x")

    print(f"\n{banner}")
    print("Benchmark Complete!")
    print(f"{banner}")
|
| |
|
| |
|
# Script entry point: run the full benchmark suite when executed directly.
if __name__ == "__main__":
    run_memory_benchmarks()
|
| |
|