| | """
|
| | BitLinear layer implementations.
|
| |
|
| | This module provides nn.Module wrappers around the functional implementations,
|
| | providing a drop-in replacement for nn.Linear with ternary weights.
|
| | """
|
| |
|
| | import math
|
| | import torch
|
| | import torch.nn as nn
|
| | import torch.nn.functional as F
|
| | from typing import Optional
|
| |
|
| | from .functional import (
|
| | bitlinear_python,
|
| | greedy_ternary_decomposition,
|
| | multi_ternary_linear_python,
|
| | )
|
| | from .quantization import weight_to_ternary
|
| |
|
| |
|
class BitLinear(nn.Module):
    """
    BitLinear layer: drop-in replacement for nn.Linear with ternary weights.

    This layer uses ternary weights ({-1, 0, +1}) instead of full-precision
    weights, achieving ~20x memory compression while maintaining competitive
    performance on Transformer models.

    Interface matches nn.Linear:
        - Same initialization arguments (in_features, out_features, bias)
        - Same forward signature
        - Can replace nn.Linear in existing architectures

    Example:
        >>> # Standard Linear
        >>> linear = nn.Linear(512, 512)
        >>> # BitLinear replacement
        >>> bitlinear = BitLinear(512, 512)
        >>> x = torch.randn(32, 128, 512)
        >>> output = bitlinear(x)  # Same interface

    Notes:
        - Weights are quantized to ternary on initialization or conversion
        - Stores ternary weights plus per-output scaling factors (gamma)
        - Forward pass uses efficient ternary matrix multiplication
        - Can be trained with QAT (Quantization-Aware Training)

    Attributes:
        in_features: Input dimension
        out_features: Output dimension
        W_ternary: Ternary weight matrix [out_features, in_features]
        gamma: Per-output scaling factors [out_features]
        bias: Optional bias term [out_features]
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """
        Initialize BitLinear layer.

        Args:
            in_features: Size of each input sample
            out_features: Size of each output sample
            bias: If True, add learnable bias (default: True)
            device: Device to place parameters on
            dtype: Data type for parameters

        Notes:
            - Dense weights are drawn with standard initialization
              (kaiming_uniform_) and quantized via weight_to_ternary()
              in reset_parameters().
            - W_ternary and gamma are registered as parameters, so a QAT
              strategy can update them; for fixed ternary weights they can
              simply be left frozen.
        """
        super().__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}

        self.in_features = in_features
        self.out_features = out_features

        # Ternary weights and per-output scales; filled in by reset_parameters().
        self.W_ternary = nn.Parameter(torch.zeros(out_features, in_features, **factory_kwargs))
        self.gamma = nn.Parameter(torch.ones(out_features, **factory_kwargs))

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_features, **factory_kwargs))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """
        Initialize layer parameters.

        Strategy:
            1. Initialize dense weights using standard scheme (kaiming_uniform_)
            2. Quantize to ternary using weight_to_ternary()
            3. Store ternary weights and scaling factors
        """
        # Draw dense weights with the same scheme nn.Linear uses, on the
        # same device/dtype as the registered parameters.
        W_dense = torch.empty(
            self.out_features, self.in_features,
            device=self.W_ternary.device, dtype=self.W_ternary.dtype,
        )
        nn.init.kaiming_uniform_(W_dense, a=math.sqrt(5))

        # Quantize to ternary with one scale per output channel.
        W_ternary, gamma = weight_to_ternary(W_dense, per_channel=True)
        self.W_ternary.data.copy_(W_ternary)
        self.gamma.data.copy_(gamma)

        # Bias initialization mirrors nn.Linear.reset_parameters().
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(W_dense)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through BitLinear layer.

        Args:
            x: Input tensor of shape [..., in_features]

        Returns:
            Output tensor of shape [..., out_features]
        """
        return bitlinear_python(x, self.W_ternary, self.gamma, self.bias)

    @classmethod
    def from_linear(cls, linear: nn.Linear) -> 'BitLinear':
        """
        Convert a standard nn.Linear layer to BitLinear.

        This allows converting pre-trained models to use ternary weights.

        Args:
            linear: Standard nn.Linear layer to convert

        Returns:
            BitLinear layer with quantized weights

        Example:
            >>> linear = nn.Linear(512, 512)
            >>> # ... train linear ...
            >>> bitlinear = BitLinear.from_linear(linear)
        """
        bitlinear = cls(
            linear.in_features,
            linear.out_features,
            bias=linear.bias is not None,
            device=linear.weight.device,
            dtype=linear.weight.dtype,
        )

        # Quantize the trained dense weights rather than re-initializing.
        W_ternary, gamma = weight_to_ternary(linear.weight.data, per_channel=True)
        bitlinear.W_ternary.data.copy_(W_ternary)
        bitlinear.gamma.data.copy_(gamma)

        if linear.bias is not None:
            bitlinear.bias.data.copy_(linear.bias.data)

        return bitlinear

    def extra_repr(self) -> str:
        """String representation for print()."""
        return f'in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}'


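# For orientation: a naive dense-equivalent of BitLinear's forward pass. This
# is only a sketch (the `_bitlinear_reference` name is hypothetical and the
# real implementation is .functional.bitlinear_python); it is the k=1 case of
# the formula in MultiTernaryLinear's docstring below:
#
#     output = (x @ W_ternary.T) * gamma + bias
#
def _bitlinear_reference(
    x: torch.Tensor,
    W_ternary: torch.Tensor,
    gamma: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Reference forward for testing/clarity; not used by the layers here."""
    out = torch.matmul(x, W_ternary.t()) * gamma  # gamma broadcasts over the last dim
    if bias is not None:
        out = out + bias
    return out

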
class MultiTernaryLinear(nn.Module):
    """
    Multi-component ternary linear layer.

    Represents a linear layer as a sum of k ternary components:

        output = sum_{i=1}^{k} (x @ W_i^T * gamma_i) + bias

    This provides a better approximation of dense weights than single
    ternary quantization, at the cost of k× more computation.

    References:
        - JMLR paper on ternary representations:
          https://jmlr.org/papers/volume26/24-2050/24-2050.pdf
        - Greedy ternary decomposition for neural networks

    Attributes:
        in_features: Input dimension
        out_features: Output dimension
        k: Number of ternary components
        W_ternary: Stacked ternary weights [k, out_features, in_features]
        gammas: Stacked scaling factors [k, out_features]
        bias: Optional bias term [out_features]

    Example:
        >>> # Single ternary component (equivalent to BitLinear)
        >>> layer = MultiTernaryLinear(512, 512, k=1)
        >>> # Multiple components for better approximation
        >>> layer = MultiTernaryLinear(512, 512, k=4)
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        k: int = 2,
        bias: bool = True,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """
        Initialize MultiTernaryLinear layer.

        Args:
            in_features: Size of each input sample
            out_features: Size of each output sample
            k: Number of ternary components (typically 2-4)
            bias: If True, add learnable bias
            device: Device to place parameters on
            dtype: Data type for parameters

        Notes:
            - Dense weights are initialized and then decomposed into k
              ternary components via greedy_ternary_decomposition() in
              reset_parameters().
            - The stacked ternary weights and gammas are registered as
              parameters; see the Attributes section for shapes.
        """
        super().__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}

        self.in_features = in_features
        self.out_features = out_features
        self.k = k

        # Stacked ternary components and per-component, per-output scales;
        # filled in by reset_parameters().
        self.W_ternary = nn.Parameter(torch.zeros(k, out_features, in_features, **factory_kwargs))
        self.gammas = nn.Parameter(torch.ones(k, out_features, **factory_kwargs))

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_features, **factory_kwargs))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """
        Initialize layer parameters using greedy ternary decomposition.
        """
        W_dense = torch.empty(
            self.out_features, self.in_features,
            device=self.W_ternary.device, dtype=self.W_ternary.dtype,
        )
        nn.init.kaiming_uniform_(W_dense, a=math.sqrt(5))

        # Decompose into k ternary components; expected shapes are
        # [k, out_features, in_features] and [k, out_features].
        W_ternary_stack, gamma_stack = greedy_ternary_decomposition(W_dense, self.k)
        self.W_ternary.data.copy_(W_ternary_stack)
        self.gammas.data.copy_(gamma_stack)

        # Bias initialization mirrors nn.Linear.reset_parameters().
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(W_dense)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through multi-ternary layer.

        Args:
            x: Input tensor of shape [..., in_features]

        Returns:
            Output tensor of shape [..., out_features]
        """
        return multi_ternary_linear_python(x, self.W_ternary, self.gammas, self.bias)

    @classmethod
    def from_linear(cls, linear: nn.Linear, k: int = 2) -> 'MultiTernaryLinear':
        """
        Convert nn.Linear to MultiTernaryLinear using greedy decomposition.

        Args:
            linear: Standard nn.Linear layer
            k: Number of ternary components

        Returns:
            MultiTernaryLinear layer
        """
        multi_ternary = cls(
            linear.in_features,
            linear.out_features,
            k=k,
            bias=linear.bias is not None,
            device=linear.weight.device,
            dtype=linear.weight.dtype,
        )

        # Decompose the trained dense weights rather than re-initializing.
        W_ternary_stack, gamma_stack = greedy_ternary_decomposition(linear.weight.data, k)
        multi_ternary.W_ternary.data.copy_(W_ternary_stack)
        multi_ternary.gammas.data.copy_(gamma_stack)

        if linear.bias is not None:
            multi_ternary.bias.data.copy_(linear.bias.data)

        return multi_ternary

    def extra_repr(self) -> str:
        """String representation."""
        return f'in_features={self.in_features}, out_features={self.out_features}, k={self.k}, bias={self.bias is not None}'


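# A sketch of the greedy residual-fitting idea behind the decomposition used
# above. The `_greedy_decomposition_sketch` name is hypothetical and the real
# routine is .functional.greedy_ternary_decomposition, whose details may
# differ; this only illustrates the principle: ternarize the current residual,
# subtract the scaled component, repeat k times.
def _greedy_decomposition_sketch(W: torch.Tensor, k: int):
    """Illustrative only; returns stacked [k, out, in] weights and [k, out] scales."""
    components, scales = [], []
    residual = W.clone()
    for _ in range(k):
        # Ternarize whatever the previous components have not yet explained...
        W_i, gamma_i = weight_to_ternary(residual, per_channel=True)
        components.append(W_i)
        scales.append(gamma_i)
        # ...and remove that component from the residual.
        residual = residual - gamma_i.unsqueeze(1) * W_i
    return torch.stack(components), torch.stack(scales)

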
def convert_linear_to_bitlinear(
    module: nn.Module,
    inplace: bool = True,
) -> nn.Module:
    """
    Recursively convert all nn.Linear layers in a module to BitLinear.

    This utility function walks through a model and replaces all Linear layers
    with BitLinear layers, useful for converting pre-trained models.

    Args:
        module: PyTorch module (e.g., a Transformer model)
        inplace: If True, modify module in place; if False, return a copy

    Returns:
        Module with Linear layers replaced by BitLinear

    Example:
        >>> model = transformers.GPT2Model.from_pretrained('gpt2')
        >>> model = convert_linear_to_bitlinear(model)
        >>> # All Linear layers are now BitLinear
    """
    if not inplace:
        import copy
        module = copy.deepcopy(module)

    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            # Replace this Linear directly on its parent module.
            setattr(module, name, BitLinear.from_linear(child))
        else:
            # Recurse into submodules (attention blocks, MLPs, etc.).
            convert_linear_to_bitlinear(child, inplace=True)

    return module
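
# A minimal smoke-test sketch. The `_smoke_test` helper is hypothetical and
# assumes the functional implementations in .functional and .quantization
# behave as documented above; because this module uses relative imports, run
# it via `python -m <package>.<module>` rather than directly.
def _smoke_test() -> None:
    """Shape checks for the layers and the conversion utility."""
    x = torch.randn(2, 8, 16)
    for layer in (BitLinear(16, 32), MultiTernaryLinear(16, 32, k=2)):
        assert layer(x).shape == (2, 8, 32)

    # Post-training conversion path.
    dense = nn.Linear(16, 32)
    assert BitLinear.from_linear(dense)(x).shape == (2, 8, 32)

    # Recursive model-wide conversion leaves no plain nn.Linear behind.
    model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 16))
    model = convert_linear_to_bitlinear(model)
    assert not any(type(m) is nn.Linear for m in model.modules())


if __name__ == '__main__':
    _smoke_test()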