from safetensors.torch import load_file
import torch
from tqdm import tqdm


__all__ = [
    'flux_load_lora'
]

def is_int(d):
    """Return True if `d` parses as an integer, e.g. a list index inside a module path."""
    try:
        int(d)
        return True
    except (TypeError, ValueError):
        return False
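# Quick sanity check for the helper above (illustrative only):
#   is_int("3")    -> True   (an indexed submodule, e.g. layers.3)
#   is_int("attn") -> False  (a named attribute)
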
def flux_load_lora(self, lora_file, lora_weight=1.0):
    """Merge the LoRA in `lora_file` into this pipeline's transformer and text encoder in place."""
    device = self.transformer.device

    # Diffusers-layout LoRA weights for the transformer. `network_alphas` is
    # returned for completeness but is not applied in the transformer branch below.
    state_dict, network_alphas = self.lora_state_dict(lora_file, return_alphas=True)
    state_dict = {k: v.to(device) for k, v in state_dict.items()}

    model = self.transformer
    keys = [k for k in state_dict if k.startswith('transformer.')]

    for k_lora in tqdm(keys, desc="loading lora in transformer ..."):
        v_lora = state_dict[k_lora]

        # Process each LoRA pair once, keyed on its up-projection (`lora_B`) weight.
        if '.lora_A.weight' in k_lora:
            continue
        if '.alpha' in k_lora:
            continue

        k_lora_name = k_lora.replace("transformer.", "")
        k_lora_name = k_lora_name.replace(".lora_B.weight", "")
        attr_name_list = k_lora_name.split('.')

        # Walk the dotted path down to the target module. A segment that is
        # not an attribute by itself is accumulated in `latest_attr_name` and
        # retried joined with the next segment.
        cur_attr = model
        latest_attr_name = ''
        for attr_name in attr_name_list:
            if is_int(attr_name):
                cur_attr = cur_attr[int(attr_name)]
                latest_attr_name = ''
            else:
                try:
                    if latest_attr_name != '':
                        cur_attr = getattr(cur_attr, f"{latest_attr_name}.{attr_name}")
                    else:
                        cur_attr = getattr(cur_attr, attr_name)
                    latest_attr_name = ''
                except AttributeError:
                    if latest_attr_name != '':
                        latest_attr_name = f"{latest_attr_name}.{attr_name}"
                    else:
                        latest_attr_name = attr_name

        up_w = v_lora
        down_w = state_dict[k_lora.replace('.lora_B.weight', '.lora_A.weight')]

        # Contract the shared rank dimension regardless of how many axes the
        # weight has: take as many letters of the templates as there are axes.
        einsum_a = "ijabcdefg"
        einsum_b = "jkabcdefg"
        einsum_res = "ikabcdefg"
        length_shape = len(up_w.shape)
        einsum_str = f"{einsum_a[:length_shape]},{einsum_b[:length_shape]}->{einsum_res[:length_shape]}"
        dtype = cur_attr.weight.data.dtype
        d_w = torch.einsum(einsum_str, up_w.to(torch.float32), down_w.to(torch.float32)).to(dtype)
        cur_attr.weight.data = cur_attr.weight.data + d_w * lora_weight
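        # For the usual 2-D (Linear) weights the string above is "ij,jk->ik",
        # so the update is just the matrix product of the LoRA factors; an
        # equivalent sketch (illustrative only, not executed):
        #   d_w = (up_w.float() @ down_w.float()).to(dtype)
        #   cur_attr.weight.data += d_w * lora_weight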

    # The text-encoder LoRA is stored under kohya-style `lora_te1_` keys
    # (first text encoder), so read the raw safetensors file again rather
    # than the diffusers-converted state dict.
    raw_state_dict = load_file(lora_file)
    raw_state_dict = {k: v.to(device) for k, v in raw_state_dict.items()}

    state_dict = {k: v for k, v in raw_state_dict.items() if 'lora_te1_' in k}
    model = self.text_encoder
    keys = [k for k in state_dict if k.startswith('lora_te1_')]

    for k_lora in tqdm(keys, desc="loading lora in text_encoder ..."):
        v_lora = state_dict[k_lora]

        # Again key on the up weight and skip its companion tensors.
        if '.lora_down.weight' in k_lora:
            continue
        if '.alpha' in k_lora:
            continue

        k_lora_name = k_lora.replace("lora_te1_", "")
        k_lora_name = k_lora_name.replace(".lora_up.weight", "")
        attr_name_list = k_lora_name.split('_')

        # kohya keys flatten the module path with underscores, so walk the
        # model with '_' as the separator, joining segments on failure as above.
        cur_attr = model
        latest_attr_name = ''
        for attr_name in attr_name_list:
            if is_int(attr_name):
                cur_attr = cur_attr[int(attr_name)]
                latest_attr_name = ''
            else:
                try:
                    if latest_attr_name != '':
                        cur_attr = getattr(cur_attr, f"{latest_attr_name}_{attr_name}")
                    else:
                        cur_attr = getattr(cur_attr, attr_name)
                    latest_attr_name = ''
                except AttributeError:
                    if latest_attr_name != '':
                        latest_attr_name = f"{latest_attr_name}_{attr_name}"
                    else:
                        latest_attr_name = attr_name

        up_w = v_lora
        down_w = state_dict[k_lora.replace('.lora_up.weight', '.lora_down.weight')]

        # kohya checkpoints carry a per-pair `alpha`; the effective scale is
        # alpha / rank, where rank is the inner dimension of the up weight.
        alpha = state_dict.get(k_lora.replace('.lora_up.weight', '.alpha'), None)
        if alpha is None:
            lora_scale = 1
        else:
            rank = up_w.shape[1]
            lora_scale = alpha / rank

        einsum_a = "ijabcdefg"
        einsum_b = "jkabcdefg"
        einsum_res = "ikabcdefg"
        length_shape = len(up_w.shape)
        einsum_str = f"{einsum_a[:length_shape]},{einsum_b[:length_shape]}->{einsum_res[:length_shape]}"
        dtype = cur_attr.weight.data.dtype
        d_w = torch.einsum(einsum_str, up_w.to(torch.float32), down_w.to(torch.float32)).to(dtype)
        cur_attr.weight.data = cur_attr.weight.data + d_w * lora_scale * lora_weight
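
# Usage sketch (illustrative; assumes a diffusers `FluxPipeline`, whose
# `transformer` / `text_encoder` module layout the keys above are resolved
# against — the model id and LoRA path below are placeholders):
#
#   from diffusers import FluxPipeline
#
#   pipe = FluxPipeline.from_pretrained(
#       "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
#   ).to("cuda")
#   pipe.flux_load_lora = flux_load_lora.__get__(pipe)  # bind as a method
#   pipe.flux_load_lora("my_lora.safetensors", lora_weight=0.8)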