import types
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import transformers
from transformers import Cache, GenerationConfig
from transformers.modeling_utils import PreTrainedModel

from .functions_2_patch import _validate_model_kwargs, llama_atten_forward
from .monkey_patching_utils import monkey_patching
from .sep_cache_utils import SepCache


# Generation arguments that this custom SepCache generation loop does not support; passing any of them
# (directly or via a non-default `generation_config`) raises a ValueError below.
UNSUPPORTED_GENERATION_ARGS = [
    "cache_implementation",
    "cache_config",
    "return_legacy_cache",
    "num_beams",
    "compile_config",
    "assistant_model",
]


def generate(
    model: PreTrainedModel,

    # Core SepCache settings: KV budgets (initial/separator/local/total), separator behavior, and padding id.
    init_cache_size: Union[int, List] = 4,
    sep_cache_size: Union[int, List] = 128,
    local_size: Union[int, List] = 256,
    cache_size: Union[int, List] = 512,
    SEP_ACCUMULATION: bool = True,
    USE_MAX_SEP_CACHE: bool = False,
    SEP_PADDING_IN_BATCH: bool = False,
    separator_token_ids: Optional[List[int]] = None,
    PADDING_ID: Optional[int] = None,

    # Optional pre-existing cache state to warm-start SepCache with.
    past_tok_ids: Optional[List[torch.Tensor]] = None,
    key_cache: Optional[List[torch.Tensor]] = None,
    value_cache: Optional[List[torch.Tensor]] = None,

    # Kept-KV ratio reporting and running counters inside SepCache.
    PRINT_KV_RATIO_INSIDE: bool = False,
    print_KV_inside_per_steps: int = 1000,
    _seen_tokens: int = 0,
    _kept_kv_ratio: Optional[List[Tuple[int]]] = None,

    # Positional-encoding (RoPE) shifting options.
    APPLY_PE_SHIFT: bool = False,
    APPLY_PES_INSIDE: bool = False,
    _shifted_position_ids: Optional[List[torch.Tensor]] = None,
    _rope_unsqueeze_dim: int = 1,
    _rope_seq_dim: int = 1,
    pe_scaling_factor: float = 1.0,
    pe_dim: int = 128,
    max_position_embeddings: int = 8192,
    base: int = 10000,

    # Tensor layout and model structure.
    k_seq_dim: int = 2,
    v_seq_dim: int = 2,
    layer_num: Optional[int] = None,

    model_type: str = 'llama',
    device=None,

    # Monkey-patching verbosity.
    monkey_patch_verbose: bool = False,

    **kwargs
):
| | """Custom generate function for SepCache. |
| | |
| | A cache as described in the [SepLLM paper - ICML 2025](https://arxiv.org/abs/2412.12094). In the training phase, |
| | SepLLM condenses the segment information into the KV of the separator that divides the segment. In the inference phase, the |
| | corresponding SepCache only needs to store the KVs of initial tokens, separator tokens, and recent tokens for generation. |
| | |
| | It stores the Key and Value states as lists of tensors, two lists for each layer. The expected shape for each tensor is |
| | `[batch_size, num_heads, seq_len, head_dim]`. |
| | |
| | Frequently-Used Parameters: |
| | |
| | `init_cache_size: Union[int, List]`: |
| | The maximum number of KVs to be stored for initial tokens. |
| | In the paper, the hyperparameter `a` is an abbreviated alias for `self.init_cache_size`. |
| | |
| | `sep_cache_size: Union[int, List]`: |
| | The maximum number of KVs to be stored for separator tokens. |
| | In the paper, the hyperparameter `s` is an abbreviated alias for `self.sep_cache_size`. |
| | |
| | `local_size: Union[int, List]`: |
| | The maximum number of KVs to be stored for local tokens (i.e., sliding window). |
| | In the paper, the hyperparameter `w` is an abbreviated alias for `self.local_size`. |
| | |
| | `cache_size: Union[int, List]`: |
| | The maximum number of KVs to be stored for all the tokens, i.e., the size for the whole KV cache. |
| | In the paper, the hyperparameter `c` is an abbreviated alias for `self.cache_size`. |
| | |
| | Concerning these four parameters above: |
| | When a list is passed (its length must be `layer_num`), it represents different values for each layer. |
| | When an integer is passed, it means the setting is the same for all layers. |
| | |
| | |
| | `USE_MAX_SEP_CACHE: bool`: |
| | If True, it means we only keep at most `self.sep_cache_size` seperators' KVs. |
| | If the number exceeds this limit, older separator's KVs will be discarded, keeping only the most recent `self.sep_cache_size` KVs. |
| | In the paper, the hyperparameter `s` is an abbreviated alias for `self.sep_cache_size`. |
| | |
| | `separator_token_ids: List[int]`: |
| | The token ids of the separator tokens for the current model's tokenizer. |
| | We have some examples, such as the Llama-3 series models, where setting `model_type='llama'` allows you |
| | to skip setting `separator_token_ids` and `PADDING_ID` (SepCache will auto-fill them). |
| | |
| | `PADDING_ID: int`: |
| | The token id of the padding token. You can just set `PADDING_ID` to the id of "<|endoftext|>" token of the tokenizer for the pretrained model. |
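
        A brief example of mixing a scalar and a per-layer list for the size parameters (the specific
        numbers are illustrative only and assume a 32-layer model; a list's length must equal `layer_num`):

        ```python
        >>> past_key_values = SepCache(
        ...     init_cache_size=4,                    # scalar: the same budget for every layer
        ...     sep_cache_size=128,
        ...     local_size=[256] * 16 + [512] * 16,   # list: a per-layer sliding-window size
        ...     cache_size=1024,
        ...     layer_num=32, USE_MAX_SEP_CACHE=True, model_type='llama')
        ```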

        Important Note:
            When `cache_size` and `local_size` are set to infinity (i.e., sufficiently large positive integers) and `USE_MAX_SEP_CACHE` is `False`, `SepCache` degenerates into a regular cache.
            However, you must always ensure that `init_cache_size` + `sep_cache_size` + `local_size` + `left_padding_offset` < `cache_size`.
            Here, `left_padding_offset` denotes the number of padding tokens in the record with the most left padding within a runtime batch; it can only be determined at runtime.
            To guarantee that this inequality always holds at runtime, leave a sufficient margin between the two sides of
            `init_cache_size` + `sep_cache_size` + `local_size` < `cache_size`, i.e., `a` + `s` + `w` < `c` in the [SepLLM paper - ICML 2025],
            so that there is room left for `left_padding_offset`.
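            For example, with this function's defaults `init_cache_size=4`, `sep_cache_size=128`, `local_size=256`, and `cache_size=512`,
            the margin is `512 - (4 + 128 + 256) = 124` KV slots, so the inequality holds as long as `left_padding_offset` stays below 124 at runtime.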

    Please refer to the comments of the `SepCache.__init__` function for more details on the parameters.

    Example:

        ```python
        >>> from transformers import AutoTokenizer, AutoModelForCausalLM
        >>> from .sep_cache_utils import SepCache
        >>> import torch
        >>> from huggingface_hub import login
        >>> login("hf_xxxXXXxxx")

        >>> def to_cuda(a_dict: dict) -> dict:
        ...     new_dict = {}
        ...     for k, v in a_dict.items():
        ...         if isinstance(v, torch.Tensor):
        ...             new_dict[k] = v.cuda()
        ...         else:
        ...             new_dict[k] = v
        ...     return new_dict

        >>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct", attn_implementation="flash_attention_2", device_map="cuda:0")
        >>> model.bfloat16().cuda()
        >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
        >>> inputs = tokenizer(text="My name is Llama 3", return_tensors="pt")
        >>> inputs = to_cuda(inputs)

        >>> # Prepare a cache and pass it to the model's forward; `layer_num` is the number of layers of the pretrained model.
        >>> past_key_values = SepCache(init_cache_size=4, sep_cache_size=128, local_size=256, cache_size=512, layer_num=32, USE_MAX_SEP_CACHE=True, model_type='llama')
        >>> # `separator_token_ids` and `PADDING_ID` must also be provided if you are not using `model_type='llama'` as in this demo.
        >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
        >>> outputs.past_key_values  # access the SepCache filled with keys/values
        SepCache()
        ```

        ```python
        >>> # When using the `update` function of SepCache to update the keys/values and the past token ids (necessary in SepCache),
        >>> # the current `input_ids` must also be provided.
        >>> key_states, value_states = past_key_values.update(
        ...     key_states=key_states,
        ...     value_states=value_states,
        ...     input_ids=input_ids,
        ...     layer_idx=layer_idx,
        ...     PREFILLING_FLAG=q_len > 1,  # `q_len` is the sequence length of the current `query_states`
        ... )
        ```

    For detailed usage instructions, please refer to https://github.com/HKUDS/SepLLM
    """

    # Patch the model in place: replace each decoder layer's attention forward with the SepLLM-aware
    # `llama_atten_forward`; `monkey_patching` returns the model's decoder layers (used below to infer `layer_num`).
    model_layers = monkey_patching(model, model_atten_forward=llama_atten_forward, verbose=monkey_patch_verbose)

    # Reject generation arguments that this custom generation loop does not support, whether they are
    # passed directly as kwargs or set to a non-default value on a user-provided `generation_config`.
    generation_config = kwargs.get("generation_config")
    default_global_generation_config = GenerationConfig()
    default_model_generation_config = model.generation_config
    for arg in UNSUPPORTED_GENERATION_ARGS:
        has_custom_gen_config_arg = (
            generation_config is not None
            # The `None` defaults guard against arguments that are not attributes of `GenerationConfig`.
            and not (
                getattr(default_model_generation_config, arg, None) == getattr(generation_config, arg, None)
                or getattr(default_global_generation_config, arg, None) == getattr(generation_config, arg, None)
            )
        )
        kwargs_has_arg = arg in kwargs and kwargs[arg] is not None
        if kwargs_has_arg or has_custom_gen_config_arg:
            raise ValueError(
                f"`{arg}` is set, but it's not supported in this custom generate function. List of "
                f"unsupported arguments: {UNSUPPORTED_GENERATION_ARGS}"
            )

    # SepCache assumes a decoder-only architecture.
    if model.config.is_encoder_decoder:
        raise ValueError("This custom generate function only works with decoder-only models")

    # Drop the `custom_generate` kwarg (if present) so that it is not forwarded to `model.generate` again.
    kwargs.pop("custom_generate", None)

    # SepCache needs the current token ids inside the patched attention forward (its `update` method takes
    # `input_ids`), so thread the prompt's `input_ids` through via `sepllm_kwargs`.
    sepllm_kwargs = {"input_ids": kwargs["input_ids"]}
    kwargs["sepllm_kwargs"] = sepllm_kwargs

    # Build a fresh SepCache unless the caller already supplied one.
    past_key_values = kwargs.pop("past_key_values", None)
    if past_key_values is None:
        past_key_values = SepCache(
            init_cache_size=init_cache_size,
            sep_cache_size=sep_cache_size,
            local_size=local_size,
            cache_size=cache_size,
            SEP_ACCUMULATION=SEP_ACCUMULATION,
            USE_MAX_SEP_CACHE=USE_MAX_SEP_CACHE,
            SEP_PADDING_IN_BATCH=SEP_PADDING_IN_BATCH,
            separator_token_ids=separator_token_ids,
            PADDING_ID=PADDING_ID,

            past_tok_ids=past_tok_ids,
            key_cache=key_cache,
            value_cache=value_cache,

            PRINT_KV_RATIO_INSIDE=PRINT_KV_RATIO_INSIDE,
            print_KV_inside_per_steps=print_KV_inside_per_steps,
            _seen_tokens=_seen_tokens,
            _kept_kv_ratio=_kept_kv_ratio,

            APPLY_PE_SHIFT=APPLY_PE_SHIFT,
            APPLY_PES_INSIDE=APPLY_PES_INSIDE,
            _shifted_position_ids=_shifted_position_ids,
            _rope_unsqueeze_dim=_rope_unsqueeze_dim,
            _rope_seq_dim=_rope_seq_dim,
            pe_scaling_factor=pe_scaling_factor,
            pe_dim=pe_dim,
            max_position_embeddings=max_position_embeddings,
            base=base,

            k_seq_dim=k_seq_dim,
            v_seq_dim=v_seq_dim,
            # Default to the number of patched decoder layers when `layer_num` is not given explicitly.
            layer_num=layer_num if layer_num is not None else len(model_layers),

            model_type=model_type,
            device=device,
        )
    elif not isinstance(past_key_values, SepCache):
        raise ValueError(f"`past_key_values` must be a `SepCache` instance, got a {type(past_key_values)} instance")

    # SepCache requires caching to be enabled; run the standard generation loop with it.
    kwargs["use_cache"] = True
    generation_outputs = model.generate(**kwargs, past_key_values=past_key_values)
    return generation_outputs
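

# Illustrative usage sketch (not part of the SepLLM API): one way to call the `generate` wrapper defined
# above on a Llama-style checkpoint. The checkpoint name, prompt, and generation settings below are
# placeholder assumptions, and this block must be run as a module inside its package
# (e.g. `python -m <package>.<this_module>`) so that the relative imports above resolve.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_name = "meta-llama/Meta-Llama-3-8B-Instruct"  # placeholder checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        attn_implementation="flash_attention_2",
        torch_dtype=torch.bfloat16,
        device_map="cuda:0",
    )

    inputs = tokenizer("My name is Llama 3", return_tensors="pt").to(model.device)

    # SepCache options not listed here fall back to the defaults of `generate` above.
    output_ids = generate(
        model,
        init_cache_size=4,
        sep_cache_size=128,
        local_size=256,
        cache_size=512,
        USE_MAX_SEP_CACHE=True,
        model_type="llama",
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_new_tokens=64,
    )
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))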