| | """ |
| | Differentiable quantizer based on scaled noise injection. |
| | """ |
| | from dataclasses import dataclass |
| | import math |
| | import typing as tp |
| |
|
| | import torch |
| |
|
| | from .base import BaseQuantizer |
| | from .uniform import uniform_quantize, uniform_unquantize |
| | from .utils import simple_repr |
| |
|
| |
|
class DiffQuantizer(BaseQuantizer):
    @dataclass
    class _QuantizedParam(BaseQuantizer._QuantizedParam):
        logit: torch.nn.Parameter

    def __init__(self, model: torch.nn.Module, min_size: float = 0.01, float16: bool = False,
                 group_size: int = 1, min_bits: float = 2, max_bits: float = 15,
                 param="bits", noise="gaussian",
                 init_bits: float = 8, extra_bits: float = 0, suffix: str = "_diffq",
                 exclude: tp.List[str] = [], detect_bound: bool = True):
| | """ |
| | Differentiable quantizer based on scaled noise injection. |
| | For every parameter `p` in the model, this introduces a number of bits parameter |
| | `b` with the same dimensions (when group_size = 1). |
| | Before each forward, `p` is replaced by `p + U` |
| | with U uniform iid noise with range [-d/2, d/2], with `d` the uniform quantization |
| | step for `b` bits. |
| | This noise approximates the quantization noise in a differentiable manner, both |
| | with respect to the unquantized parameter `p` and the number of bits `b`. |
| | |
| | At eveluation (as detected with `model.eval()`), the model is replaced |
| | by its true quantized version, and restored when going back to training. |
| | |
| | When doing actual quantization (for serialization, or evaluation), |
| | the number of bits is rounded to the nearest integer, and needs to be stored along. |
| | This will cost a few bits per dimension. To reduce this cost, one can use `group_size`, |
| | which will use a single noise level for multiple weight entries. |
| | |
| | You can use the `DiffQuantizer.model_size` method to get a differentiable estimate of the |
| | model size in MB. You can then use this estimate as a penalty in your training loss. |
| | |
| | Args: |
| | model (torch.nn.Module): model to quantize |
| | min_size (float): minimum size in MB of a parameter to be quantized. |
| | float16 (bool): if a layer is smaller than min_size, should we still do float16? |
| | group_size (int): weight entries are groupped together to reduce the number |
| | of noise scales to store. This should divide the size of all parameters |
| | bigger than min_size. |
| | min_bits (float): minimal number of bits. |
| | max_bits (float): maximal number of bits. |
| | init_bits (float): initial number of bits. |
| | extra_bits (float): extra bits to add for actual quantization (before roundoff). |
| | suffix (str): suffix used for the name of the extra noise scale parameters. |
| | exclude (list[str]): list of patterns used to match parameters to exclude. |
| | For instance `['bias']` to exclude all bias terms. |
| | detect_bound (bool): if True, will detect bound parameters and reuse |
| | the same quantized tensor for both, as well as the same number of bits. |
| | |
| | ..Warning:: |
| | You must call `model.training()` and `model.eval()` for `DiffQuantizer` work properly. |
| | |
| | """ |
        self.group_size = group_size
        self.min_bits = min_bits
        self.max_bits = max_bits
        self.init_bits = init_bits
        self.extra_bits = extra_bits
        self.suffix = suffix
        self.param = param
        self.noise = noise
        assert noise in ["gaussian", "uniform"]
        self._optimizer_setup = False

        self._min_noise = 1 / (2 ** self.max_bits - 1)
        self._max_noise = 1 / (2 ** self.min_bits - 1)

        assert group_size >= 0
        assert min_bits < init_bits < max_bits, \
            "init_bits must be strictly between min_bits and max_bits."

        for name, _ in model.named_parameters():
            if name.endswith(suffix):
                raise RuntimeError("The model already has some noise scale parameters, "
                                   "maybe you applied DiffQuantizer twice on the same model?")

        super().__init__(model, min_size, float16, exclude, detect_bound)

    def _get_bits(self, logit: torch.Tensor):
        if self.param == "noise":
            return torch.log2(1 + 1 / self._get_noise_scale(logit))
        else:
            t = torch.sigmoid(logit)
            return self.max_bits * t + (1 - t) * self.min_bits

    def _get_noise_scale(self, logit: torch.Tensor):
        if self.param == "noise":
            t = torch.sigmoid(logit)
            return torch.exp(t * math.log(self._min_noise) + (1 - t) * math.log(self._max_noise))
        else:
            return 1 / (2 ** self._get_bits(logit) - 1)

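    # Worked example of the bits <-> noise-scale correspondence implemented above
    # (illustrative numbers, not part of the original source): `b` bits corresponds
    # to a relative quantization step of 1 / (2**b - 1), so with the defaults
    # min_bits=2 and max_bits=15:
    #
    #   b = 8  ->  scale = 1 / 255   ~= 3.9e-3
    #   b = 2  ->  scale = 1 / 3     (this is `_max_noise`)
    #   b = 15 ->  scale = 1 / 32767 (this is `_min_noise`)
    #
    # With param="noise", the logit interpolates log-linearly between `_min_noise`
    # and `_max_noise`, and `_get_bits` inverts that mapping via log2(1 + 1/scale).
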
    def _register_param(self, name, param, module, other):
        if other is not None:
            return self.__class__._QuantizedParam(
                name=name, param=param, module=module, logit=other.logit, other=other)
        assert self.group_size == 0 or param.numel() % self.group_size == 0

        if self.param == "noise":
            noise_scale = 1 / (2 ** self.init_bits - 1)
            t = (math.log(noise_scale) - math.log(self._max_noise)) / (
                math.log(self._min_noise) - math.log(self._max_noise))
        else:
            t = (self.init_bits - self.min_bits) / (self.max_bits - self.min_bits)
        assert 0 < t < 1
        logit = torch.logit(torch.tensor(float(t)))
        assert abs(self._get_bits(logit) - self.init_bits) < 1e-5
        if self.group_size > 0:
            nparam = param.numel() // self.group_size
        else:
            nparam = 1
        logit = torch.nn.Parameter(
            torch.full(
                (nparam,),
                logit,
                device=param.device))
        module.register_parameter(name + self.suffix, logit)
        return self.__class__._QuantizedParam(
            name=name, param=param, module=module, logit=logit, other=None)

    def clear_optimizer(self, optimizer: torch.optim.Optimizer):
        """
        Remove the bits parameters from the given optimizer
        (the reverse of `setup_optimizer`).
        """
        params = [qp.logit for qp in self._qparams]

        for group in optimizer.param_groups:
            new_params = [q for q in group["params"]
                          if not any(p is q for p in params)]
            group["params"][:] = new_params

    def setup_optimizer(self, optimizer: torch.optim.Optimizer,
                        lr: float = 1e-3, **kwargs):
        """
        Setup the optimizer to tune the number of bits. In particular, this will deactivate
        weight decay for the bits parameters.

        Args:
            optimizer (torch.optim.Optimizer): optimizer to use.
            lr (float): specific learning rate for the bits parameters. 1e-3
                works well with Adam.
            kwargs (dict): overrides for other optimization parameters for the bits.
        """
        assert not self._optimizer_setup
        self._optimizer_setup = True

        params = [qp.logit for qp in self._qparams]

        for group in optimizer.param_groups:
            for q in list(group["params"]):
                for p in params:
                    if p is q:
                        raise RuntimeError("You should create the optimizer "
                                           "before the quantizer!")

        group = {"params": params, "lr": lr, "weight_decay": 0}
        group.update(kwargs)
        optimizer.add_param_group(group)

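    # A minimal ordering sketch (hypothetical `MyModel`, not from the original
    # source): the optimizer must be created from the model parameters *before*
    # the quantizer, so that the bits logits are not already registered in it.
    #
    #     model = MyModel()
    #     optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    #     quantizer = DiffQuantizer(model)
    #     quantizer.setup_optimizer(optimizer)  # adds the logits, without weight decay
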
    def no_optimizer(self):
        """
        Call this if you do not want to use an optimizer.
        """
        self._optimizer_setup = True

    def check_unused(self):
        """
        Zero the gradient of the bits logits for parameters whose own gradient
        is missing or all zero (i.e. parameters unused in the last backward).
        """
        for qparam in self._qparams:
            if qparam.other is not None:
                continue
            grad = qparam.param.grad
            if grad is None or (grad == 0).all():
                if qparam.logit.grad is not None:
                    qparam.logit.grad.data.zero_()

    def model_size(self, exact=False):
        """
        Differentiable estimate of the model size.
        The size is returned in MB.

        If `exact` is True, then the output is no longer differentiable but
        reflects exactly an achievable size, even without compression,
        i.e. the same as returned by `naive_model_size()`.
        """
        total = super().model_size()
        subtotal = 0
        for qparam in self._qparams:
            # only count the first appearance of a bound Parameter
            if qparam.other is not None:
                continue
            bits = self.extra_bits + self._get_bits(qparam.logit)
            if exact:
                bits = bits.round().clamp(1, 15)
            if self.group_size == 0:
                group_size = qparam.param.numel()
            else:
                group_size = self.group_size
            subtotal += group_size * bits.sum()
            subtotal += 2 * 32  # param scale stored in float32

            # number of bits required to store each per-group number of bits
            bits_bits = math.ceil(math.log2(1 + (bits.max().round().item() - self.min_bits)))
            subtotal += 8  # 8 bits to store `bits_bits` itself
            subtotal += bits_bits * bits.numel()

        subtotal /= 2 ** 20 * 8  # bits -> megabytes
        return total + subtotal

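    # Worked example of the size estimate above (illustrative numbers, not from
    # the original source): a 1,000,000-entry weight with group_size=8 and an
    # average of 4 bits per entry costs about
    #
    #   1_000_000 * 4 bits              (quantized levels)
    #   + (1_000_000 / 8) * bits_bits   (per-group number of bits)
    #   + 2 * 32 + 8                    (scale and bookkeeping)
    #
    # i.e. roughly 0.5 MB plus the per-group overhead, versus ~3.8 MB in float32.
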
    def true_model_size(self):
        """
        Exact model size in MB, without extra (e.g. zlib) compression.
        """
        return self.model_size(exact=True).item()

    def _pre_forward_train(self):
        if not self._optimizer_setup:
            raise RuntimeError("You must call `setup_optimizer()` on your optimizer "
                               "before starting training.")
        for qparam in self._qparams:
            if qparam.other is not None:
                noisy = qparam.other.module._parameters[qparam.other.name]
            else:
                bits = self._get_bits(qparam.logit)[:, None]
                if self.group_size == 0:
                    p_flat = qparam.param.view(-1)
                else:
                    p_flat = qparam.param.view(-1, self.group_size)
                scale = p_flat.max() - p_flat.min()
                unit = 1 / (2 ** bits - 1)
                if self.noise == "uniform":
                    noise_source = (torch.rand_like(p_flat) - 0.5)
                elif self.noise == "gaussian":
                    noise_source = torch.randn_like(p_flat) / 2
                noise = scale * unit * noise_source
                noisy = p_flat + noise
            # assigning to `_parameters` directly bypasses PyTorch's checks
            # on parameters being leaf tensors
            qparam.module._parameters[qparam.name] = noisy.view_as(qparam.param)
        return True

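    # The injected noise above mirrors the quantization step (illustrative
    # summary, not from the original source): for `b` bits the step is
    # d = scale / (2**b - 1), and the noise is drawn either uniformly in
    # [-d/2, d/2] or as a Gaussian with standard deviation d/2.
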
    def _post_forward_train(self):
        for qparam in self._qparams:
            qparam.module._parameters[qparam.name] = qparam.param
        return True

    def _quantize_param(self, qparam: _QuantizedParam) -> tp.Any:
        bits = self.extra_bits + self._get_bits(qparam.logit)
        bits = bits.round().clamp(1, 15)[:, None].byte()
        if self.group_size == 0:
            p = qparam.param.data.view(-1)
        else:
            p = qparam.param.data.view(-1, self.group_size)
        levels, scales = uniform_quantize(p, bits)
        return levels, scales, bits

    def _unquantize_param(self, qparam: _QuantizedParam, quantized: tp.Any) -> torch.Tensor:
        levels, param_scale, bits = quantized
        return uniform_unquantize(levels, param_scale, bits).view_as(qparam.param.data)

    def detach(self):
        super().detach()
        for qparam in self._qparams:
            delattr(qparam.module, qparam.name + self.suffix)

    def __repr__(self):
        return simple_repr(self)
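
# A minimal usage sketch (not part of the original source): the penalty weight,
# layer sizes and data below are made up for illustration. It shows the workflow
# from the class docstring: create the optimizer first, then the quantizer, and
# add `model_size()` as a differentiable penalty to the training loss.
#
#     import torch
#     from diffq import DiffQuantizer
#
#     model = torch.nn.Sequential(torch.nn.Linear(64, 256), torch.nn.ReLU(),
#                                 torch.nn.Linear(256, 10))
#     optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
#     quantizer = DiffQuantizer(model, min_size=0)  # quantize even small layers
#     quantizer.setup_optimizer(optimizer)
#
#     penalty = 10  # higher values push towards fewer bits
#     model.train()  # required: noise is only injected in training mode
#     x, y = torch.randn(8, 64), torch.randint(10, (8,))
#     loss = torch.nn.functional.cross_entropy(model(x), y)
#     loss = loss + penalty * quantizer.model_size()
#     optimizer.zero_grad()
#     loss.backward()
#     optimizer.step()
#
#     model.eval()  # swaps in the truly quantized weights
#     print(f"estimated size: {quantizer.true_model_size():.3f} MB")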