Add AutoModel registration and code polish

#9
by Seas0 - opened
Files changed (2) hide show
  1. config.json +1 -0
  2. modeling_stable_diffcoder.py +9 -6
config.json CHANGED
@@ -3,6 +3,7 @@
3
  "StableDiffcoderForCausalLM"
4
  ],
5
  "auto_map": {
 
6
  "AutoModelForCausalLM": "modeling_stable_diffcoder.StableDiffcoderForCausalLM"
7
  },
8
  "attention_bias": false,
 
3
  "StableDiffcoderForCausalLM"
4
  ],
5
  "auto_map": {
6
+ "AutoModel": "modeling_stable_diffcoder.StableDiffcoderForCausalLM",
7
  "AutoModelForCausalLM": "modeling_stable_diffcoder.StableDiffcoderForCausalLM"
8
  },
9
  "attention_bias": false,
modeling_stable_diffcoder.py CHANGED
@@ -158,12 +158,15 @@ class StableDiffcoderForCausalLM(LlamaForCausalLM):
158
  nfe = 0
159
  final_flag = False
160
  prefill_length = prompt_length // block_length * block_length
161
-
162
  if prefill_length > 0:
163
  cur_attn_mask = block_diffusion_attention_mask[
164
  ..., :prefill_length, :prefill_length
165
  ]
166
  # Fix 1: Explicitly pass cache_position for newer transformers prefill
 
 
 
167
  cache_pos = torch.arange(prefill_length, device=x.device)
168
  self(
169
  x[:, :prefill_length],
@@ -211,17 +214,17 @@ class StableDiffcoderForCausalLM(LlamaForCausalLM):
211
  remasking,
212
  mask_map,
213
  x[:, block_start:block_end],
214
- token_count.item() if threshold is None else None,
215
  threshold,
216
  shift=shift,
217
  )
218
  x[:, block_start:block_end][transfer_map] = x0[transfer_map]
219
 
220
  if (x[:, block_start:block_end] == mask_id).sum() == 0:
221
-
222
  # Fix 2: Calculate where the generated tokens ACTUALLY start in this block
223
  gen_start = max(block_start, prompt_length)
224
-
225
  if (
226
  eos_id is not None
227
  and gen_start < block_end
@@ -232,7 +235,7 @@ class StableDiffcoderForCausalLM(LlamaForCausalLM):
232
  eos_pos = (x[:, gen_start:block_end] == eos_id).nonzero(as_tuple=True)[1][0].item() + gen_start
233
  x[0, eos_pos:] = eos_id
234
  break
235
-
236
  nfe += 1
237
  self(
238
  x[:, block_start:block_end],
@@ -243,7 +246,7 @@ class StableDiffcoderForCausalLM(LlamaForCausalLM):
243
  use_cache=True,
244
  cache_position=replace_position.nonzero(as_tuple=True)[1],
245
  )
246
- break
247
 
248
  if final_flag:
249
  break
 
158
  nfe = 0
159
  final_flag = False
160
  prefill_length = prompt_length // block_length * block_length
161
+
162
  if prefill_length > 0:
163
  cur_attn_mask = block_diffusion_attention_mask[
164
  ..., :prefill_length, :prefill_length
165
  ]
166
  # Fix 1: Explicitly pass cache_position for newer transformers prefill
167
+ # actually not necessary since transformers will automatically generate it for prefilling
168
+ # if unspecified, but the official `generate` method does pass it,
169
+ # so we follow that for consistency and to avoid potential issues in future transformers updates
170
  cache_pos = torch.arange(prefill_length, device=x.device)
171
  self(
172
  x[:, :prefill_length],
 
214
  remasking,
215
  mask_map,
216
  x[:, block_start:block_end],
217
+ token_count.item() if threshold is None else None,
218
  threshold,
219
  shift=shift,
220
  )
221
  x[:, block_start:block_end][transfer_map] = x0[transfer_map]
222
 
223
  if (x[:, block_start:block_end] == mask_id).sum() == 0:
224
+
225
  # Fix 2: Calculate where the generated tokens ACTUALLY start in this block
226
  gen_start = max(block_start, prompt_length)
227
+
228
  if (
229
  eos_id is not None
230
  and gen_start < block_end
 
235
  eos_pos = (x[:, gen_start:block_end] == eos_id).nonzero(as_tuple=True)[1][0].item() + gen_start
236
  x[0, eos_pos:] = eos_id
237
  break
238
+
239
  nfe += 1
240
  self(
241
  x[:, block_start:block_end],
 
246
  use_cache=True,
247
  cache_position=replace_position.nonzero(as_tuple=True)[1],
248
  )
249
+ break
250
 
251
  if final_flag:
252
  break