BiliSakura commited on
Commit
c836618
·
verified ·
1 Parent(s): c7626bd

Upload folder using huggingface_hub

Browse files
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ __pycache__
2
+ wget-log*
3
+ *.pyc
PixNerd-XL-16-256/README.md ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # PixNerd-XL-16-256
2
+
3
+ Self-contained PixNerd-XL/16 checkpoint inside [`BiliSakura/PixNerd-diffusers`](https://huggingface.co/BiliSakura/PixNerd-diffusers). Runtime dependencies: this folder + PyPI `diffusers`/`torch` only.
4
+
5
+ ## Hub path
6
+
7
+ `BiliSakura/PixNerd-diffusers/PixNerd-XL-16-256`
8
+
9
+ ## Layout
10
+
11
+ ```text
12
+ PixNerd-XL-16-256/
13
+ ├── pipeline.py
14
+ ├── model_index.json
15
+ ├── conversion_metadata.json
16
+ ├── transformer/
17
+ └── scheduler/
18
+ ```
19
+
20
+ ## Load
21
+
22
+ ```python
23
+ import torch
24
+ from diffusers import DiffusionPipeline
25
+
26
+ pipe = DiffusionPipeline.from_pretrained(
27
+ "BiliSakura/PixNerd-diffusers/PixNerd-XL-16-256",
28
+ trust_remote_code=True,
29
+ torch_dtype=torch.float32,
30
+ ).to("cuda")
31
+
32
+ images = pipe(
33
+ prompt=207,
34
+ height=256,
35
+ width=256,
36
+ num_inference_steps=25,
37
+ guidance_scale=4.0,
38
+ timeshift=3.0,
39
+ order=2,
40
+ ).images
41
+ ```
PixNerd-XL-16-256/model_index.json CHANGED
@@ -1,12 +1,15 @@
1
  {
2
- "_class_name": "PixNerdPipeline",
3
- "_diffusers_version": "0.30.0",
 
 
 
4
  "scheduler": [
5
- "diffusers_modules.local.scheduling_pixnerd_flow_match",
6
  "PixNerdFlowMatchScheduler"
7
  ],
8
  "transformer": [
9
- "diffusers_modules.local.modeling_pixnerd_transformer_2d",
10
  "PixNerdTransformer2DModel"
11
  ]
12
  }
 
1
  {
2
+ "_class_name": [
3
+ "pipeline",
4
+ "PixNerdPipeline"
5
+ ],
6
+ "_diffusers_version": "0.36.0",
7
  "scheduler": [
8
+ "scheduling_pixnerd_flow_match",
9
  "PixNerdFlowMatchScheduler"
10
  ],
11
  "transformer": [
12
+ "modeling_pixnerd_transformer_2d",
13
  "PixNerdTransformer2DModel"
14
  ]
15
  }
PixNerd-XL-16-256/pipeline.py CHANGED
@@ -1,7 +1,9 @@
1
  from __future__ import annotations
2
 
 
3
  from dataclasses import dataclass
4
- from typing import List, Optional, Sequence, Union
 
5
 
6
  import torch
7
  from diffusers import DiffusionPipeline
@@ -9,10 +11,8 @@ from diffusers.image_processor import VaeImageProcessor
9
  from diffusers.utils import BaseOutput
10
  from PIL import Image
11
 
12
- from .modeling_pixnerd_transformer_2d import PixNerdTransformer2DModel
13
- from .scheduling_pixnerd_flow_match import PixNerdFlowMatchScheduler
14
-
15
  ConditioningInput = Union[str, int, Sequence[Union[str, int]]]
 
16
 
17
 
18
  @dataclass
@@ -27,9 +27,11 @@ class PixNerdPipeline(DiffusionPipeline):
27
  def __init__(
28
  self,
29
  transformer,
30
- scheduler: PixNerdFlowMatchScheduler,
31
  vae=None,
32
  conditioner=None,
 
 
33
  ):
34
  super().__init__()
35
  if vae is None:
@@ -46,6 +48,170 @@ class PixNerdPipeline(DiffusionPipeline):
46
  )
47
  self.image_processor = VaeImageProcessor(vae_scale_factor=1)
48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  @staticmethod
50
  def _fp_to_uint8(image: torch.Tensor) -> torch.Tensor:
51
  return torch.clip_((image + 1) * 127.5 + 0.5, 0, 255).to(torch.uint8)
@@ -71,10 +237,11 @@ class PixNerdPipeline(DiffusionPipeline):
71
  num_images_per_prompt: int,
72
  ):
73
  prompts = self._repeat(self._to_list(prompt), num_images_per_prompt)
 
74
  metadata = {"device": self._execution_device}
75
  with torch.no_grad():
76
- cond, uncond = self.conditioner(prompts, metadata)
77
- return cond, uncond, prompts
78
 
79
  def prepare_latents(
80
  self,
@@ -124,9 +291,10 @@ class PixNerdPipeline(DiffusionPipeline):
124
  cond, default_uncond, prompts = self.encode_prompt(prompt, num_images_per_prompt)
125
  if negative_prompt is not None:
126
  negative = self._repeat(self._to_list(negative_prompt), num_images_per_prompt)
 
127
  metadata = {"device": self._execution_device}
128
  with torch.no_grad():
129
- _, uncond = self.conditioner(negative, metadata)
130
  else:
131
  uncond = default_uncond
132
  batch_size = len(prompts)
@@ -178,6 +346,7 @@ class PixNerdPipeline(DiffusionPipeline):
178
  return (output,)
179
  return PixNerdPipelineOutput(images=output)
180
 
 
181
  __all__ = [
182
  "PixNerdPipeline",
183
  "PixNerdPipelineOutput",
 
1
  from __future__ import annotations
2
 
3
+ import sys
4
  from dataclasses import dataclass
5
+ from pathlib import Path
6
+ from typing import List, Literal, Optional, Sequence, Union
7
 
8
  import torch
9
  from diffusers import DiffusionPipeline
 
11
  from diffusers.utils import BaseOutput
12
  from PIL import Image
13
 
 
 
 
14
  ConditioningInput = Union[str, int, Sequence[Union[str, int]]]
15
+ Language = Literal["en", "cn"]
16
 
17
 
18
  @dataclass
 
27
  def __init__(
28
  self,
29
  transformer,
30
+ scheduler,
31
  vae=None,
32
  conditioner=None,
33
+ id2label: Optional[dict[int, str]] = None,
34
+ id2label_cn: Optional[dict[int, str]] = None,
35
  ):
36
  super().__init__()
37
  if vae is None:
 
48
  )
49
  self.image_processor = VaeImageProcessor(vae_scale_factor=1)
50
 
51
+ if id2label is None and id2label_cn is None:
52
+ id2label, id2label_cn = self._load_repo_labels()
53
+ self._id2label = id2label or {}
54
+ self._id2label_cn = id2label_cn or {}
55
+ self.labels = self._build_label2id(self._id2label)
56
+ self.labels_cn = self._build_label2id(self._id2label_cn)
57
+ self._labels_loaded_from_path = bool(self._id2label or self._id2label_cn)
58
+
59
+ def _ensure_labels_loaded(self) -> None:
60
+ if self._labels_loaded_from_path:
61
+ return
62
+
63
+ path = getattr(getattr(self, "config", None), "_name_or_path", None) or getattr(self, "_name_or_path", None)
64
+ if not path:
65
+ return
66
+
67
+ id2label, id2label_cn = self._load_labels_for_path(path)
68
+ if id2label is None and id2label_cn is None:
69
+ self._labels_loaded_from_path = True
70
+ return
71
+
72
+ self._id2label = id2label or {}
73
+ self._id2label_cn = id2label_cn or {}
74
+ self.labels = self._build_label2id(self._id2label)
75
+ self.labels_cn = self._build_label2id(self._id2label_cn)
76
+ self._labels_loaded_from_path = True
77
+
78
+ @staticmethod
79
+ def _resolve_labels_dir(pretrained_model_name_or_path: Union[str, Path]) -> Optional[Path]:
80
+ path = Path(pretrained_model_name_or_path)
81
+ if not path.exists():
82
+ try:
83
+ from huggingface_hub import snapshot_download
84
+
85
+ path = Path(snapshot_download(pretrained_model_name_or_path))
86
+ except Exception:
87
+ return None
88
+
89
+ if (path / "model_index.json").exists():
90
+ labels_dir = path.parent / "labels"
91
+ else:
92
+ labels_dir = path / "labels"
93
+ return labels_dir if labels_dir.is_dir() else None
94
+
95
+ @classmethod
96
+ def _load_labels_for_path(
97
+ cls,
98
+ pretrained_model_name_or_path: Union[str, Path],
99
+ ) -> tuple[Optional[dict[int, str]], Optional[dict[int, str]]]:
100
+ labels_dir = cls._resolve_labels_dir(pretrained_model_name_or_path)
101
+ if labels_dir is None:
102
+ return None, None
103
+
104
+ labels_path = str(labels_dir)
105
+ inserted = False
106
+ if labels_path not in sys.path:
107
+ sys.path.insert(0, labels_path)
108
+ inserted = True
109
+ try:
110
+ from imagenet_labels import load_id2label
111
+
112
+ return (
113
+ load_id2label(labels_dir, lang="en"),
114
+ load_id2label(labels_dir, lang="cn"),
115
+ )
116
+ finally:
117
+ if inserted and labels_path in sys.path:
118
+ sys.path.remove(labels_path)
119
+
120
+ @staticmethod
121
+ def _load_repo_labels() -> tuple[Optional[dict[int, str]], Optional[dict[int, str]]]:
122
+ labels_dir = Path(__file__).resolve().parent.parent / "labels"
123
+ if not labels_dir.is_dir():
124
+ return None, None
125
+
126
+ labels_path = str(labels_dir)
127
+ inserted = False
128
+ if labels_path not in sys.path:
129
+ sys.path.insert(0, labels_path)
130
+ inserted = True
131
+ try:
132
+ from imagenet_labels import load_id2label
133
+
134
+ return (
135
+ load_id2label(labels_dir, lang="en"),
136
+ load_id2label(labels_dir, lang="cn"),
137
+ )
138
+ finally:
139
+ if inserted and labels_path in sys.path:
140
+ sys.path.remove(labels_path)
141
+
142
+ @classmethod
143
+ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
144
+ pipe = super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
145
+ id2label, id2label_cn = cls._load_labels_for_path(pretrained_model_name_or_path)
146
+ if id2label is not None or id2label_cn is not None:
147
+ pipe._id2label = id2label or {}
148
+ pipe._id2label_cn = id2label_cn or {}
149
+ pipe.labels = cls._build_label2id(pipe._id2label)
150
+ pipe.labels_cn = cls._build_label2id(pipe._id2label_cn)
151
+ return pipe
152
+
153
+ @staticmethod
154
+ def _build_label2id(id2label: dict[int, str]) -> dict[str, int]:
155
+ label2id: dict[str, int] = {}
156
+ for class_id, value in id2label.items():
157
+ for synonym in value.split(","):
158
+ synonym = synonym.strip()
159
+ if synonym:
160
+ label2id[synonym] = int(class_id)
161
+ return dict(sorted(label2id.items()))
162
+
163
+ @property
164
+ def id2label(self) -> dict[int, str]:
165
+ self._ensure_labels_loaded()
166
+ return self._id2label
167
+
168
+ @property
169
+ def id2label_cn(self) -> dict[int, str]:
170
+ self._ensure_labels_loaded()
171
+ return self._id2label_cn
172
+
173
+ def get_label_ids(
174
+ self,
175
+ labels: Union[str, List[str]],
176
+ *,
177
+ lang: Language = "en",
178
+ ) -> List[int]:
179
+ self._ensure_labels_loaded()
180
+ if isinstance(labels, str):
181
+ labels = [labels]
182
+
183
+ label2id = self.labels if lang == "en" else self.labels_cn
184
+ if not label2id:
185
+ raise ValueError(
186
+ f"No {lang} labels loaded. Ensure `labels/id2label_{lang}.json` exists next to the variant folder."
187
+ )
188
+
189
+ missing = [label for label in labels if label not in label2id]
190
+ if missing:
191
+ preview = ", ".join(list(label2id.keys())[:8])
192
+ raise ValueError(
193
+ f"Unknown label(s) for lang={lang!r}: {missing}. Example valid labels: {preview}, ..."
194
+ )
195
+ return [label2id[label] for label in labels]
196
+
197
+ def _resolve_prompt_item(self, value: Union[str, int]) -> int:
198
+ if isinstance(value, int):
199
+ return value
200
+ if value.isdigit():
201
+ return int(value)
202
+ if value in self.labels:
203
+ return self.labels[value]
204
+ if value in self.labels_cn:
205
+ return self.labels_cn[value]
206
+ raise ValueError(
207
+ f"Unknown class label {value!r}. Pass an ImageNet class id or a synonym from "
208
+ "`pipe.labels` / `pipe.labels_cn`."
209
+ )
210
+
211
+ def _resolve_prompts(self, prompts: List[Union[str, int]]) -> List[int]:
212
+ self._ensure_labels_loaded()
213
+ return [self._resolve_prompt_item(prompt) for prompt in prompts]
214
+
215
  @staticmethod
216
  def _fp_to_uint8(image: torch.Tensor) -> torch.Tensor:
217
  return torch.clip_((image + 1) * 127.5 + 0.5, 0, 255).to(torch.uint8)
 
237
  num_images_per_prompt: int,
238
  ):
239
  prompts = self._repeat(self._to_list(prompt), num_images_per_prompt)
240
+ resolved = self._resolve_prompts(prompts)
241
  metadata = {"device": self._execution_device}
242
  with torch.no_grad():
243
+ cond, uncond = self.conditioner(resolved, metadata)
244
+ return cond, uncond, resolved
245
 
246
  def prepare_latents(
247
  self,
 
291
  cond, default_uncond, prompts = self.encode_prompt(prompt, num_images_per_prompt)
292
  if negative_prompt is not None:
293
  negative = self._repeat(self._to_list(negative_prompt), num_images_per_prompt)
294
+ resolved_negative = self._resolve_prompts(negative)
295
  metadata = {"device": self._execution_device}
296
  with torch.no_grad():
297
+ _, uncond = self.conditioner(resolved_negative, metadata)
298
  else:
299
  uncond = default_uncond
300
  batch_size = len(prompts)
 
346
  return (output,)
347
  return PixNerdPipelineOutput(images=output)
348
 
349
+
350
  __all__ = [
351
  "PixNerdPipeline",
352
  "PixNerdPipelineOutput",
PixNerd-XL-16-256/scheduler/scheduling_pixnerd_flow_match.py ADDED
@@ -0,0 +1,231 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Any, Dict, List, Optional, Tuple, Union
5
+
6
+ import torch
7
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
8
+ from diffusers.schedulers.scheduling_utils import SchedulerMixin
9
+ from diffusers.utils import BaseOutput
10
+
11
+ @dataclass
12
+ class PixNerdSchedulerOutput(BaseOutput):
13
+ prev_sample: torch.Tensor
14
+
15
+
16
+ class PixNerdFlowMatchScheduler(SchedulerMixin, ConfigMixin):
17
+ """
18
+ Diffusers-compatible scheduler wrapper for PixNerd's AdamLM flow-matching sampler.
19
+ """
20
+
21
+ config_name = "scheduler_config.json"
22
+ order = 1
23
+ init_noise_sigma = 1.0
24
+
25
+ @staticmethod
26
+ def _lagrange_coeffs(order: int, pre_ts: torch.Tensor, t_start: torch.Tensor, t_end: torch.Tensor) -> List[float]:
27
+ ts = [float(v) for v in pre_ts[-order:].tolist()]
28
+ a = float(t_start)
29
+ b = float(t_end)
30
+
31
+ if order == 1:
32
+ return [1.0]
33
+ if order == 2:
34
+ t1, t2 = ts
35
+ int1 = 0.5 / (t1 - t2) * ((b - t2) ** 2 - (a - t2) ** 2)
36
+ int2 = 0.5 / (t2 - t1) * ((b - t1) ** 2 - (a - t1) ** 2)
37
+ total = int1 + int2
38
+ return [int1 / total, int2 / total]
39
+ if order == 3:
40
+ t1, t2, t3 = ts
41
+ int1_denom = (t1 - t2) * (t1 - t3)
42
+ int1 = ((1 / 3) * b**3 - 0.5 * (t2 + t3) * b**2 + (t2 * t3) * b) - (
43
+ (1 / 3) * a**3 - 0.5 * (t2 + t3) * a**2 + (t2 * t3) * a
44
+ )
45
+ int1 = int1 / int1_denom
46
+ int2_denom = (t2 - t1) * (t2 - t3)
47
+ int2 = ((1 / 3) * b**3 - 0.5 * (t1 + t3) * b**2 + (t1 * t3) * b) - (
48
+ (1 / 3) * a**3 - 0.5 * (t1 + t3) * a**2 + (t1 * t3) * a
49
+ )
50
+ int2 = int2 / int2_denom
51
+ int3_denom = (t3 - t1) * (t3 - t2)
52
+ int3 = ((1 / 3) * b**3 - 0.5 * (t1 + t2) * b**2 + (t1 * t2) * b) - (
53
+ (1 / 3) * a**3 - 0.5 * (t1 + t2) * a**2 + (t1 * t2) * a
54
+ )
55
+ int3 = int3 / int3_denom
56
+ total = int1 + int2 + int3
57
+ return [int1 / total, int2 / total, int3 / total]
58
+ if order == 4:
59
+ t1, t2, t3, t4 = ts
60
+ int1_denom = (t1 - t2) * (t1 - t3) * (t1 - t4)
61
+ int1 = ((1 / 4) * b**4 - (1 / 3) * (t2 + t3 + t4) * b**3 + 0.5 * (t3 * t4 + t2 * t3 + t2 * t4) * b**2 - (t2 * t3 * t4) * b) - (
62
+ (1 / 4) * a**4 - (1 / 3) * (t2 + t3 + t4) * a**3 + 0.5 * (t3 * t4 + t2 * t3 + t2 * t4) * a**2 - (t2 * t3 * t4) * a
63
+ )
64
+ int1 = int1 / int1_denom
65
+ int2_denom = (t2 - t1) * (t2 - t3) * (t2 - t4)
66
+ int2 = ((1 / 4) * b**4 - (1 / 3) * (t1 + t3 + t4) * b**3 + 0.5 * (t3 * t4 + t1 * t3 + t1 * t4) * b**2 - (t1 * t3 * t4) * b) - (
67
+ (1 / 4) * a**4 - (1 / 3) * (t1 + t3 + t4) * a**3 + 0.5 * (t3 * t4 + t1 * t3 + t1 * t4) * a**2 - (t1 * t3 * t4) * a
68
+ )
69
+ int2 = int2 / int2_denom
70
+ int3_denom = (t3 - t1) * (t3 - t2) * (t3 - t4)
71
+ int3 = ((1 / 4) * b**4 - (1 / 3) * (t1 + t2 + t4) * b**3 + 0.5 * (t4 * t2 + t1 * t2 + t1 * t4) * b**2 - (t1 * t2 * t4) * b) - (
72
+ (1 / 4) * a**4 - (1 / 3) * (t1 + t2 + t4) * a**3 + 0.5 * (t4 * t2 + t1 * t2 + t1 * t4) * a**2 - (t1 * t2 * t4) * a
73
+ )
74
+ int3 = int3 / int3_denom
75
+ int4_denom = (t4 - t1) * (t4 - t2) * (t4 - t3)
76
+ int4 = ((1 / 4) * b**4 - (1 / 3) * (t1 + t2 + t3) * b**3 + 0.5 * (t3 * t2 + t1 * t2 + t1 * t3) * b**2 - (t1 * t2 * t3) * b) - (
77
+ (1 / 4) * a**4 - (1 / 3) * (t1 + t2 + t3) * a**3 + 0.5 * (t3 * t2 + t1 * t2 + t1 * t3) * a**2 - (t1 * t2 * t3) * a
78
+ )
79
+ int4 = int4 / int4_denom
80
+ total = int1 + int2 + int3 + int4
81
+ return [int1 / total, int2 / total, int3 / total, int4 / total]
82
+ raise ValueError(f"Unsupported solver order: {order}.")
83
+
84
+ @register_to_config
85
+ def __init__(
86
+ self,
87
+ num_train_timesteps: int = 1000,
88
+ num_inference_steps: int = 25,
89
+ guidance_scale: float = 4.0,
90
+ timeshift: float = 3.0,
91
+ order: int = 2,
92
+ guidance_interval_min: float = 0.0,
93
+ guidance_interval_max: float = 1.0,
94
+ last_step: Optional[float] = None,
95
+ ) -> None:
96
+ self.num_inference_steps = int(num_inference_steps)
97
+ self.guidance_scale = float(guidance_scale)
98
+ self.timeshift = float(timeshift)
99
+ self.order = int(order)
100
+ self.guidance_interval_min = float(guidance_interval_min)
101
+ self.guidance_interval_max = float(guidance_interval_max)
102
+ self.last_step = last_step
103
+ self._reset_state()
104
+
105
+ @classmethod
106
+ def from_sampler_spec(cls, sampler_spec: Dict[str, Any]) -> "PixNerdFlowMatchScheduler":
107
+ init_args = dict(sampler_spec.get("init_args", {}))
108
+ return cls(
109
+ num_inference_steps=int(init_args.get("num_steps", 25)),
110
+ guidance_scale=float(init_args.get("guidance", 4.0)),
111
+ timeshift=float(init_args.get("timeshift", 3.0)),
112
+ order=int(init_args.get("order", 2)),
113
+ guidance_interval_min=float(init_args.get("guidance_interval_min", 0.0)),
114
+ guidance_interval_max=float(init_args.get("guidance_interval_max", 1.0)),
115
+ last_step=init_args.get("last_step"),
116
+ )
117
+
118
+ def _reset_state(self) -> None:
119
+ self.timesteps: Optional[torch.Tensor] = None
120
+ self._timedeltas: Optional[torch.Tensor] = None
121
+ self._solver_coeffs = None
122
+ self._model_outputs = []
123
+ self._step_index = 0
124
+
125
+ @staticmethod
126
+ def _shift_respace_fn(t: torch.Tensor, shift: float = 3.0) -> torch.Tensor:
127
+ return t / (t + (1 - t) * shift)
128
+
129
+ def _build_solver_state(
130
+ self,
131
+ num_inference_steps: int,
132
+ timeshift: float,
133
+ device: Optional[Union[str, torch.device]] = None,
134
+ ) -> Tuple[torch.Tensor, torch.Tensor, List[List[float]]]:
135
+ last_step = self.last_step
136
+ if last_step is None:
137
+ last_step = 1.0 / float(num_inference_steps)
138
+
139
+ endpoints = torch.linspace(0.0, 1 - float(last_step), int(num_inference_steps), dtype=torch.float32)
140
+ endpoints = torch.cat([endpoints, torch.tensor([1.0], dtype=torch.float32)], dim=0)
141
+ timesteps = self._shift_respace_fn(endpoints, timeshift).to(device=device)
142
+ timedeltas = (timesteps[1:] - timesteps[:-1]).to(device=device)
143
+
144
+ solver_coeffs: List[List[float]] = [[] for _ in range(int(num_inference_steps))]
145
+ for i in range(int(num_inference_steps)):
146
+ order = min(self.order, i + 1)
147
+ pre_ts = timesteps[: i + 1]
148
+ coeffs = self._lagrange_coeffs(order, pre_ts, pre_ts[i], timesteps[i + 1])
149
+ solver_coeffs[i] = coeffs
150
+ return timesteps[:-1], timedeltas, solver_coeffs
151
+
152
+ def set_timesteps(
153
+ self,
154
+ num_inference_steps: Optional[int] = None,
155
+ device: Optional[Union[str, torch.device]] = None,
156
+ timeshift: Optional[float] = None,
157
+ guidance_scale: Optional[float] = None,
158
+ order: Optional[int] = None,
159
+ **kwargs: Any,
160
+ ) -> None:
161
+ if num_inference_steps is not None:
162
+ self.num_inference_steps = int(num_inference_steps)
163
+ if timeshift is not None:
164
+ self.timeshift = float(timeshift)
165
+ if guidance_scale is not None:
166
+ self.guidance_scale = float(guidance_scale)
167
+ if order is not None:
168
+ self.order = int(order)
169
+
170
+ timesteps, timedeltas, solver_coeffs = self._build_solver_state(
171
+ self.num_inference_steps,
172
+ self.timeshift,
173
+ device=device,
174
+ )
175
+ self.timesteps = timesteps
176
+ self._timedeltas = timedeltas
177
+ self._solver_coeffs = solver_coeffs
178
+ self._model_outputs = []
179
+ self._step_index = 0
180
+
181
+ def scale_model_input(self, sample: torch.Tensor, timestep: Optional[torch.Tensor] = None) -> torch.Tensor:
182
+ return sample
183
+
184
+ def classifier_free_guidance(self, model_output: torch.Tensor) -> torch.Tensor:
185
+ if model_output.shape[0] % 2 != 0:
186
+ raise ValueError("Classifier-free guidance expects concatenated unconditional/conditional batches.")
187
+ uncond, cond = model_output.chunk(2, dim=0)
188
+ return uncond + self.guidance_scale * (cond - uncond)
189
+
190
+ def step(
191
+ self,
192
+ model_output: torch.Tensor,
193
+ timestep: Union[torch.Tensor, float, int],
194
+ sample: torch.Tensor,
195
+ return_dict: bool = True,
196
+ **kwargs: Any,
197
+ ) -> Union[PixNerdSchedulerOutput, Tuple[torch.Tensor]]:
198
+ if self.timesteps is None or self._timedeltas is None or self._solver_coeffs is None:
199
+ raise RuntimeError("`set_timesteps` must be called before `step`.")
200
+ if self._step_index >= len(self._solver_coeffs):
201
+ raise RuntimeError("Scheduler step index exceeded configured timesteps.")
202
+
203
+ coeffs = self._solver_coeffs[self._step_index]
204
+ self._model_outputs.append(model_output)
205
+ order = len(coeffs)
206
+ pred = torch.zeros_like(model_output)
207
+ recent = self._model_outputs[-order:]
208
+ for coeff, output in zip(coeffs, recent):
209
+ pred = pred + coeff * output
210
+
211
+ prev_sample = sample + pred * self._timedeltas[self._step_index]
212
+ self._step_index += 1
213
+
214
+ if not return_dict:
215
+ return (prev_sample,)
216
+ return PixNerdSchedulerOutput(prev_sample=prev_sample)
217
+
218
+ def add_noise(
219
+ self,
220
+ original_samples: torch.Tensor,
221
+ noise: torch.Tensor,
222
+ timesteps: torch.Tensor,
223
+ ) -> torch.Tensor:
224
+ alpha = timesteps.view(-1, 1, 1, 1)
225
+ sigma = (1.0 - timesteps).view(-1, 1, 1, 1)
226
+ return alpha * original_samples + sigma * noise
227
+
228
+ __all__ = [
229
+ "PixNerdFlowMatchScheduler",
230
+ "PixNerdSchedulerOutput",
231
+ ]
PixNerd-XL-16-256/transformer/config.json CHANGED
@@ -3,13 +3,13 @@
3
  "_diffusers_version": "0.36.0",
4
  "compile_denoiser": false,
5
  "conditioner_spec": {
6
- "class_path": "diffusers_modules.local.modeling_pixnerd_transformer_2d.LabelConditioner",
7
  "init_args": {
8
  "num_classes": 1000
9
  }
10
  },
11
  "denoiser_spec": {
12
- "class_path": "diffusers_modules.local.modeling_pixnerd_transformer_2d.PixNerDiT",
13
  "init_args": {
14
  "hidden_size": 1152,
15
  "hidden_size_x": 64,
@@ -26,7 +26,7 @@
26
  "ema_decay": 0.9999,
27
  "use_ema": true,
28
  "vae_spec": {
29
- "class_path": "diffusers_modules.local.modeling_pixnerd_transformer_2d.PixelAE",
30
  "init_args": {
31
  "scale": 1.0
32
  }
 
3
  "_diffusers_version": "0.36.0",
4
  "compile_denoiser": false,
5
  "conditioner_spec": {
6
+ "class_path": "modeling_pixnerd_transformer_2d.LabelConditioner",
7
  "init_args": {
8
  "num_classes": 1000
9
  }
10
  },
11
  "denoiser_spec": {
12
+ "class_path": "modeling_pixnerd_transformer_2d.PixNerDiT",
13
  "init_args": {
14
  "hidden_size": 1152,
15
  "hidden_size_x": 64,
 
26
  "ema_decay": 0.9999,
27
  "use_ema": true,
28
  "vae_spec": {
29
+ "class_path": "modeling_pixnerd_transformer_2d.PixelAE",
30
  "init_args": {
31
  "scale": 1.0
32
  }
PixNerd-XL-16-256/transformer/modeling_pixnerd_transformer_2d.py ADDED
@@ -0,0 +1,749 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import copy
4
+ import importlib
5
+ import math
6
+ import sys
7
+ from dataclasses import dataclass
8
+ from functools import lru_cache
9
+ from typing import Any, Dict, Iterable, List, Optional, Tuple
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
14
+ from diffusers.models.modeling_utils import ModelMixin
15
+ from diffusers.utils import BaseOutput
16
+ from torch.nn.functional import scaled_dot_product_attention
17
+
18
+ class BaseAE(torch.nn.Module):
19
+ def __init__(self, scale=1.0, shift=0.0):
20
+ super().__init__()
21
+ self.scale = scale
22
+ self.shift = shift
23
+
24
+ def encode(self, x):
25
+ return self._impl_encode(x) #.to(torch.bfloat16)
26
+
27
+ # @torch.autocast("cuda", dtype=torch.bfloat16)
28
+ def decode(self, x):
29
+ return self._impl_decode(x) #.to(torch.bfloat16)
30
+
31
+ def _impl_encode(self, x):
32
+ raise NotImplementedError
33
+
34
+ def _impl_decode(self, x):
35
+ raise NotImplementedError
36
+
37
+ def uint82fp(x):
38
+ x = x.to(torch.float32)
39
+ x = (x - 127.5) / 127.5
40
+ return x
41
+
42
+ def fp2uint8(x):
43
+ x = torch.clip_((x + 1) * 127.5 + 0.5, 0, 255).to(torch.uint8)
44
+ return x
45
+
46
+
47
+ class PixelAE(BaseAE):
48
+ def __init__(self, scale=1.0, shift=0.0):
49
+ super().__init__(scale, shift)
50
+
51
+ def _impl_encode(self, x):
52
+ return x/self.scale+self.shift
53
+
54
+ def _impl_decode(self, x):
55
+ return (x-self.shift)*self.scale
56
+
57
+
58
+ def resolve_conditioner_device(metadata: dict, fallback: torch.device | None = None) -> torch.device:
59
+ if metadata is None:
60
+ metadata = {}
61
+ if "device" in metadata and metadata["device"] is not None:
62
+ return torch.device(metadata["device"])
63
+ if fallback is not None:
64
+ return fallback
65
+ return torch.device("cuda" if torch.cuda.is_available() else "cpu")
66
+
67
+
68
+ class BaseConditioner(nn.Module):
69
+ def __init__(self):
70
+ super(BaseConditioner, self).__init__()
71
+
72
+ def _impl_condition(self, y, metadata)->torch.Tensor:
73
+ raise NotImplementedError()
74
+
75
+ def _impl_uncondition(self, y, metadata)->torch.Tensor:
76
+ raise NotImplementedError()
77
+
78
+ @torch.no_grad()
79
+ def __call__(self, y, metadata:dict={}):
80
+ condition = self._impl_condition(y, metadata)
81
+ uncondition = self._impl_uncondition(y, metadata)
82
+ if condition.dtype in [torch.float64, torch.float32, torch.float16]:
83
+ condition = condition.to(torch.bfloat16)
84
+ if uncondition.dtype in [torch.float64,torch.float32, torch.float16]:
85
+ uncondition = uncondition.to(torch.bfloat16)
86
+ return condition, uncondition
87
+
88
+
89
+ class ComposeConditioner(BaseConditioner):
90
+ def __init__(self, conditioners:List[BaseConditioner]):
91
+ super().__init__()
92
+ self.conditioners = conditioners
93
+
94
+ def _impl_condition(self, y, metadata):
95
+ condition = []
96
+ for conditioner in self.conditioners:
97
+ condition.append(conditioner._impl_condition(y, metadata))
98
+ condition = torch.cat(condition, dim=1)
99
+ return condition
100
+
101
+ def _impl_uncondition(self, y, metadata):
102
+ uncondition = []
103
+ for conditioner in self.conditioners:
104
+ uncondition.append(conditioner._impl_uncondition(y, metadata))
105
+ uncondition = torch.cat(uncondition, dim=1)
106
+ return uncondition
107
+
108
+
109
+ class LabelConditioner(BaseConditioner):
110
+ def __init__(self, num_classes):
111
+ super().__init__()
112
+ self.null_condition = num_classes
113
+
114
+ def _impl_condition(self, y, metadata):
115
+ device = resolve_conditioner_device(metadata)
116
+ return torch.tensor(y, device=device).long()
117
+
118
+ def _impl_uncondition(self, y, metadata):
119
+ device = resolve_conditioner_device(metadata)
120
+ return torch.full((len(y),), self.null_condition, dtype=torch.long, device=device)
121
+
122
+
123
+ def modulate(x, shift, scale):
124
+ return x * (1 + scale) + shift
125
+
126
+ class Embed(nn.Module):
127
+ def __init__(
128
+ self,
129
+ in_chans: int = 3,
130
+ embed_dim: int = 768,
131
+ norm_layer = None,
132
+ bias: bool = True,
133
+ ):
134
+ super().__init__()
135
+ self.in_chans = in_chans
136
+ self.embed_dim = embed_dim
137
+ self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
138
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
139
+ def forward(self, x):
140
+ x = self.proj(x)
141
+ x = self.norm(x)
142
+ return x
143
+
144
+ class TimestepEmbedder(nn.Module):
145
+
146
+ def __init__(self, hidden_size, frequency_embedding_size=256):
147
+ super().__init__()
148
+ self.mlp = nn.Sequential(
149
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
150
+ nn.SiLU(),
151
+ nn.Linear(hidden_size, hidden_size, bias=True),
152
+ )
153
+ self.frequency_embedding_size = frequency_embedding_size
154
+
155
+ @staticmethod
156
+ def timestep_embedding(t, dim, max_period=10):
157
+ half = dim // 2
158
+ freqs = torch.exp(
159
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
160
+ )
161
+ args = t[..., None].float() * freqs[None, ...]
162
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
163
+ if dim % 2:
164
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
165
+ return embedding
166
+
167
+ def forward(self, t):
168
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
169
+ t_emb = self.mlp(t_freq)
170
+ return t_emb
171
+
172
+ class LabelEmbedder(nn.Module):
173
+ def __init__(self, num_classes, hidden_size):
174
+ super().__init__()
175
+ self.embedding_table = nn.Embedding(num_classes, hidden_size)
176
+ self.num_classes = num_classes
177
+
178
+ def forward(self, labels,):
179
+ embeddings = self.embedding_table(labels)
180
+ return embeddings
181
+
182
+ class FinalLayer(nn.Module):
183
+ def __init__(self, hidden_size, out_channels):
184
+ super().__init__()
185
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
186
+ self.linear = nn.Linear(hidden_size, out_channels, bias=True)
187
+ self.adaLN_modulation = nn.Sequential(
188
+ nn.Linear(hidden_size, 2*hidden_size, bias=True)
189
+ )
190
+
191
+ def forward(self, x, c):
192
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
193
+ x = modulate(self.norm_final(x), shift, scale)
194
+ x = self.linear(x)
195
+ return x
196
+
197
+ class RMSNorm(nn.Module):
198
+ def __init__(self, hidden_size, eps=1e-6):
199
+ """
200
+ LlamaRMSNorm is equivalent to T5LayerNorm
201
+ """
202
+ super().__init__()
203
+ self.weight = nn.Parameter(torch.ones(hidden_size))
204
+ self.variance_epsilon = eps
205
+
206
+ def forward(self, hidden_states):
207
+ input_dtype = hidden_states.dtype
208
+ hidden_states = hidden_states.to(torch.float32)
209
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
210
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
211
+ return self.weight * hidden_states.to(input_dtype)
212
+
213
+ class FeedForward(nn.Module):
214
+ def __init__(
215
+ self,
216
+ dim: int,
217
+ hidden_dim: int,
218
+ ):
219
+ super().__init__()
220
+ hidden_dim = int(2 * hidden_dim / 3)
221
+ self.w1 = nn.Linear(dim, hidden_dim, bias=False)
222
+ self.w3 = nn.Linear(dim, hidden_dim, bias=False)
223
+ self.w2 = nn.Linear(hidden_dim, dim, bias=False)
224
+ def forward(self, x):
225
+ x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
226
+ return x
227
+
228
+ def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0):
229
+ # assert H * H == end
230
+ # flat_patch_pos = torch.linspace(-1, 1, end) # N = end
231
+ x_pos = torch.linspace(0, scale, width)
232
+ y_pos = torch.linspace(0, scale, height)
233
+ y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
234
+ y_pos = y_pos.reshape(-1)
235
+ x_pos = x_pos.reshape(-1)
236
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4
237
+ x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4
238
+ y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4
239
+ x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)
240
+ y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)
241
+ freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2
242
+ freqs_cis = freqs_cis.reshape(height*width, -1)
243
+ return freqs_cis
244
+
245
+
246
+ def apply_rotary_emb(
247
+ xq: torch.Tensor,
248
+ xk: torch.Tensor,
249
+ freqs_cis: torch.Tensor,
250
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
251
+ freqs_cis = freqs_cis[None, :, None, :]
252
+ # xq : B N H Hc
253
+ xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2
254
+ xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
255
+ xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc
256
+ xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
257
+ return xq_out.type_as(xq), xk_out.type_as(xk)
258
+
259
+
260
+ class RAttention(nn.Module):
261
+ def __init__(
262
+ self,
263
+ dim: int,
264
+ num_heads: int = 8,
265
+ qkv_bias: bool = False,
266
+ qk_norm: bool = True,
267
+ attn_drop: float = 0.,
268
+ proj_drop: float = 0.,
269
+ norm_layer: nn.Module = RMSNorm,
270
+ ) -> None:
271
+ super().__init__()
272
+ assert dim % num_heads == 0, 'dim should be divisible by num_heads'
273
+
274
+ self.dim = dim
275
+ self.num_heads = num_heads
276
+ self.head_dim = dim // num_heads
277
+ self.scale = self.head_dim ** -0.5
278
+
279
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
280
+ self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
281
+ self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
282
+ self.attn_drop = nn.Dropout(attn_drop)
283
+ self.proj = nn.Linear(dim, dim)
284
+ self.proj_drop = nn.Dropout(proj_drop)
285
+
286
+ def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor:
287
+ B, N, C = x.shape
288
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4)
289
+ q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc
290
+ q = self.q_norm(q)
291
+ k = self.k_norm(k)
292
+ q, k = apply_rotary_emb(q, k, freqs_cis=pos)
293
+ q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc
294
+ k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc
295
+ v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()
296
+
297
+ x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
298
+
299
+ x = x.transpose(1, 2).reshape(B, N, C)
300
+ x = self.proj(x)
301
+ x = self.proj_drop(x)
302
+ return x
303
+
304
+
305
+
306
+ class FlattenDiTBlock(nn.Module):
307
+ def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
308
+ super().__init__()
309
+ self.norm1 = RMSNorm(hidden_size, eps=1e-6)
310
+ self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False)
311
+ self.norm2 = RMSNorm(hidden_size, eps=1e-6)
312
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
313
+ self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
314
+ self.adaLN_modulation = nn.Sequential(
315
+ nn.Linear(hidden_size, 6 * hidden_size, bias=True)
316
+ )
317
+
318
+ def forward(self, x, c, pos, mask=None):
319
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
320
+ x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask)
321
+ x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
322
+ return x
323
+
324
+ class NerfEmbedder(nn.Module):
325
+ def __init__(self, in_channels, hidden_size_input, max_freqs):
326
+ super().__init__()
327
+ self.max_freqs = max_freqs
328
+ self.hidden_size_input = hidden_size_input
329
+ self.embedder = nn.Sequential(
330
+ nn.Linear(in_channels+max_freqs**2, hidden_size_input, bias=True),
331
+ )
332
+
333
+ @lru_cache
334
+ def fetch_pos(self, patch_size, device, dtype):
335
+ pos_x = torch.linspace(0, 1, patch_size, device=device, dtype=dtype)
336
+ pos_y = torch.linspace(0, 1, patch_size, device=device, dtype=dtype)
337
+ pos_y, pos_x = torch.meshgrid(pos_y, pos_x, indexing="ij")
338
+ pos_x = pos_x.reshape(-1, 1, 1)
339
+ pos_y = pos_y.reshape(-1, 1, 1)
340
+
341
+ freqs = torch.linspace(0, self.max_freqs, self.max_freqs, dtype=dtype, device=device)
342
+ freqs_x = freqs[None, :, None]
343
+ freqs_y = freqs[None, None, :]
344
+ coeffs = (1 + freqs_x * freqs_y) ** -1
345
+ dct_x = torch.cos(pos_x * freqs_x * torch.pi)
346
+ dct_y = torch.cos(pos_y * freqs_y * torch.pi)
347
+ dct = (dct_x * dct_y * coeffs).view(1, -1, self.max_freqs ** 2)
348
+ return dct
349
+
350
+
351
+ def forward(self, inputs):
352
+ B, P2, C = inputs.shape
353
+ patch_size = int(P2 ** 0.5)
354
+ device = inputs.device
355
+ dtype = inputs.dtype
356
+ dct = self.fetch_pos(patch_size, device, dtype)
357
+ dct = dct.repeat(B, 1, 1)
358
+ inputs = torch.cat([inputs, dct], dim=-1)
359
+ inputs = self.embedder(inputs)
360
+ return inputs
361
+
362
+
363
+ class NerfBlock(nn.Module):
364
+ def __init__(self, hidden_size_s, hidden_size_x, mlp_ratio=4):
365
+ super().__init__()
366
+ self.param_generator1 = nn.Sequential(
367
+ nn.Linear(hidden_size_s, 2*hidden_size_x**2*mlp_ratio, bias=True),
368
+ )
369
+ self.norm = RMSNorm(hidden_size_x, eps=1e-6)
370
+ self.mlp_ratio = mlp_ratio
371
+ def forward(self, x, s):
372
+ batch_size, num_x, hidden_size_x = x.shape
373
+ mlp_params1 = self.param_generator1(s)
374
+ fc1_param1, fc2_param1 = mlp_params1.chunk(2, dim=-1)
375
+ fc1_param1 = fc1_param1.view(batch_size, hidden_size_x, hidden_size_x*self.mlp_ratio)
376
+ fc2_param1 = fc2_param1.view(batch_size, hidden_size_x*self.mlp_ratio, hidden_size_x)
377
+
378
+ # normalize fc1
379
+ normalized_fc1_param1 = torch.nn.functional.normalize(fc1_param1, dim=-2)
380
+ # normalize fc2
381
+ normalized_fc2_param1 = torch.nn.functional.normalize(fc2_param1, dim=-2)
382
+ # mlp 1
383
+ res_x = x
384
+ x = self.norm(x)
385
+ x = torch.bmm(x, normalized_fc1_param1)
386
+ x = torch.nn.functional.silu(x)
387
+ x = torch.bmm(x, normalized_fc2_param1)
388
+ x = x + res_x
389
+ return x
390
+
391
+ class NerfFinalLayer(nn.Module):
392
+ def __init__(self, hidden_size, out_channels):
393
+ super().__init__()
394
+ self.norm = RMSNorm(hidden_size, eps=1e-6)
395
+ self.linear = nn.Linear(hidden_size, out_channels, bias=True)
396
+ def forward(self, x):
397
+ x = self.norm(x)
398
+ x = self.linear(x)
399
+ return x
400
+
401
+ class PixNerDiT(nn.Module):
402
+ def __init__(
403
+ self,
404
+ in_channels=4,
405
+ num_groups=12,
406
+ hidden_size=1152,
407
+ hidden_size_x=64,
408
+ nerf_mlpratio=4,
409
+ num_blocks=18,
410
+ num_cond_blocks=4,
411
+ patch_size=2,
412
+ num_classes=1000,
413
+ learn_sigma=True,
414
+ deep_supervision=0,
415
+ weight_path=None,
416
+ load_ema=False,
417
+ ):
418
+ super().__init__()
419
+ self.deep_supervision = deep_supervision
420
+ self.learn_sigma = learn_sigma
421
+ self.in_channels = in_channels
422
+ self.out_channels = in_channels
423
+ self.hidden_size = hidden_size
424
+ self.num_groups = num_groups
425
+ self.num_blocks = num_blocks
426
+ self.num_cond_blocks = num_cond_blocks
427
+ self.patch_size = patch_size
428
+ self.x_embedder = NerfEmbedder(in_channels, hidden_size_x, max_freqs=8)
429
+ self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
430
+ self.t_embedder = TimestepEmbedder(hidden_size)
431
+ self.y_embedder = LabelEmbedder(num_classes+1, hidden_size)
432
+
433
+ self.final_layer = NerfFinalLayer(hidden_size_x, self.out_channels)
434
+
435
+ self.weight_path = weight_path
436
+
437
+ self.load_ema = load_ema
438
+ self.blocks = nn.ModuleList([
439
+ FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_cond_blocks)
440
+ ])
441
+ self.blocks.extend([
442
+ NerfBlock(self.hidden_size, hidden_size_x, nerf_mlpratio) for _ in range(self.num_cond_blocks, self.num_blocks)
443
+ ])
444
+ self.initialize_weights()
445
+ self.precompute_pos = dict()
446
+
447
+ def fetch_pos(self, height, width, device):
448
+ if (height, width) in self.precompute_pos:
449
+ return self.precompute_pos[(height, width)].to(device)
450
+ else:
451
+ pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
452
+ self.precompute_pos[(height, width)] = pos
453
+ return pos
454
+
455
+ def initialize_weights(self):
456
+ # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
457
+ w = self.s_embedder.proj.weight.data
458
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
459
+ nn.init.constant_(self.s_embedder.proj.bias, 0)
460
+
461
+ # Initialize label embedding table:
462
+ nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
463
+
464
+ # Initialize timestep embedding MLP:
465
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
466
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
467
+
468
+ # zero init final layer
469
+ nn.init.zeros_(self.final_layer.linear.weight)
470
+ nn.init.zeros_(self.final_layer.linear.bias)
471
+
472
+
473
+ def forward(self, x, t, y, s=None, mask=None):
474
+ B, _, H, W = x.shape
475
+ pos = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device)
476
+ x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
477
+ t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size)
478
+ y = self.y_embedder(y).view(B, 1, self.hidden_size)
479
+ c = nn.functional.silu(t + y)
480
+ if s is None:
481
+ s = self.s_embedder(x)
482
+ for i in range(self.num_cond_blocks):
483
+ s = self.blocks[i](s, c, pos, mask)
484
+ s = nn.functional.silu(t + s)
485
+ batch_size, length, _ = s.shape
486
+ x = x.reshape(batch_size*length, self.in_channels, self.patch_size**2)
487
+ x = x.transpose(1, 2)
488
+ s = s.view(batch_size*length, self.hidden_size)
489
+ x = self.x_embedder(x)
490
+ for i in range(self.num_cond_blocks, self.num_blocks):
491
+ x = self.blocks[i](x, s)
492
+ x = self.final_layer(x)
493
+ x = x.transpose(1, 2)
494
+ x = x.reshape(batch_size, length, -1)
495
+ x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size)
496
+ return x
497
+
498
+
499
+ def to_container(config: Any) -> Any:
500
+ if hasattr(config, "items") and not isinstance(config, dict):
501
+ return {k: to_container(v) for k, v in config.items()}
502
+ if isinstance(config, list):
503
+ return [to_container(v) for v in config]
504
+ return config
505
+
506
+
507
+ def load_symbol(path: str) -> Any:
508
+ module_path, name = path.rsplit(".", 1)
509
+ if module_path in {__name__, "modeling_pixnerd_transformer_2d"}:
510
+ return getattr(sys.modules[__name__], name)
511
+ module = importlib.import_module(module_path)
512
+ return getattr(module, name)
513
+
514
+
515
+ def instantiate_from_spec(spec: Any) -> Any:
516
+ spec = to_container(spec)
517
+ if isinstance(spec, dict) and "class_path" in spec:
518
+ class_or_fn = load_symbol(spec["class_path"])
519
+ init_args = spec.get("init_args", {})
520
+ if isinstance(init_args, dict):
521
+ init_args = {k: instantiate_from_spec(v) for k, v in init_args.items()}
522
+ return class_or_fn(**init_args)
523
+ if isinstance(spec, dict):
524
+ return {k: instantiate_from_spec(v) for k, v in spec.items()}
525
+ if isinstance(spec, list):
526
+ return [instantiate_from_spec(v) for v in spec]
527
+ if isinstance(spec, str) and "." in spec:
528
+ try:
529
+ return load_symbol(spec)
530
+ except Exception:
531
+ return spec
532
+ return spec
533
+
534
+
535
+ def clone_spec(spec: Dict[str, Any]) -> Dict[str, Any]:
536
+ return copy.deepcopy(to_container(spec))
537
+
538
+
539
+ def load_prefixed_state_dict(
540
+ module: Optional[torch.nn.Module],
541
+ state_dict: Dict[str, torch.Tensor],
542
+ prefixes: Iterable[str],
543
+ ) -> bool:
544
+ if module is None:
545
+ return False
546
+ for prefix in prefixes:
547
+ subset = {
548
+ key[len(prefix) :]: value
549
+ for key, value in state_dict.items()
550
+ if key.startswith(prefix)
551
+ }
552
+ if subset:
553
+ module.load_state_dict(subset, strict=False)
554
+ return True
555
+ return False
556
+
557
+
558
+ @dataclass
559
+ class PixNerdTransformer2DModelOutput(BaseOutput):
560
+ sample: torch.FloatTensor
561
+
562
+
563
+ class PixNerdTransformer2DModel(ModelMixin, ConfigMixin):
564
+ config_name = "config.json"
565
+
566
+ @register_to_config
567
+ def __init__(
568
+ self,
569
+ denoiser_spec: Dict[str, Any],
570
+ conditioner_spec: Dict[str, Any],
571
+ vae_spec: Optional[Dict[str, Any]] = None,
572
+ diffusion_trainer_spec: Optional[Dict[str, Any]] = None,
573
+ use_ema: bool = True,
574
+ ema_decay: float = 0.9999,
575
+ compile_denoiser: bool = False,
576
+ ) -> None:
577
+ super().__init__()
578
+ self.denoiser = instantiate_from_spec(to_container(denoiser_spec))
579
+ self.conditioner = instantiate_from_spec(to_container(conditioner_spec))
580
+ self.vae = instantiate_from_spec(to_container(vae_spec)) if vae_spec is not None else None
581
+ self.diffusion_trainer = (
582
+ instantiate_from_spec(to_container(diffusion_trainer_spec))
583
+ if diffusion_trainer_spec is not None
584
+ else None
585
+ )
586
+
587
+ self.use_ema = bool(use_ema)
588
+ self.ema_decay = float(ema_decay)
589
+ self.ema_denoiser = copy.deepcopy(self.denoiser) if self.use_ema else None
590
+ if self.ema_denoiser is not None:
591
+ self.ema_denoiser.to(torch.float32)
592
+
593
+ if compile_denoiser and hasattr(self.denoiser, "compile"):
594
+ self.denoiser.compile()
595
+ if self.ema_denoiser is not None:
596
+ self.ema_denoiser.compile()
597
+
598
+ self._freeze_non_trainable_modules()
599
+ if self.ema_denoiser is not None:
600
+ self.sync_ema()
601
+
602
+ @property
603
+ def patch_size(self) -> int:
604
+ return int(getattr(self.denoiser, "patch_size", 1))
605
+
606
+ @property
607
+ def in_channels(self) -> int:
608
+ return int(getattr(self.denoiser, "in_channels", 3))
609
+
610
+ @classmethod
611
+ def from_project_config(
612
+ cls,
613
+ model_config: Dict[str, Any],
614
+ use_ema: bool = True,
615
+ compile_denoiser: bool = False,
616
+ ) -> "PixNerdTransformer2DModel":
617
+ model_config = to_container(model_config)
618
+ ema_decay = model_config.get("ema_tracker", {}).get("init_args", {}).get("decay", 0.9999)
619
+ return cls(
620
+ denoiser_spec=model_config["denoiser"],
621
+ conditioner_spec=model_config["conditioner"],
622
+ vae_spec=model_config.get("vae"),
623
+ diffusion_trainer_spec=model_config.get("diffusion_trainer"),
624
+ use_ema=use_ema,
625
+ ema_decay=ema_decay,
626
+ compile_denoiser=compile_denoiser,
627
+ )
628
+
629
+ @staticmethod
630
+ def _as_timestep_tensor(
631
+ timestep: Any,
632
+ batch_size: int,
633
+ device: torch.device,
634
+ ) -> torch.Tensor:
635
+ if isinstance(timestep, torch.Tensor):
636
+ if timestep.ndim == 0:
637
+ return timestep.repeat(batch_size).to(device=device, dtype=torch.float32)
638
+ return timestep.to(device=device, dtype=torch.float32)
639
+ return torch.full((batch_size,), float(timestep), device=device, dtype=torch.float32)
640
+
641
+ def _freeze_module(self, module: Optional[torch.nn.Module]) -> None:
642
+ if module is None:
643
+ return
644
+ module.eval()
645
+ for parameter in module.parameters():
646
+ parameter.requires_grad = False
647
+
648
+ def _freeze_non_trainable_modules(self) -> None:
649
+ self._freeze_module(self.conditioner)
650
+ self._freeze_module(self.vae)
651
+ self._freeze_module(self.ema_denoiser)
652
+
653
+ def forward(
654
+ self,
655
+ sample: torch.Tensor,
656
+ timestep: Any,
657
+ encoder_hidden_states: torch.Tensor,
658
+ return_dict: bool = True,
659
+ ) -> PixNerdTransformer2DModelOutput | Tuple[torch.Tensor]:
660
+ t = self._as_timestep_tensor(timestep, sample.shape[0], sample.device)
661
+ out = self.denoiser(sample, t, encoder_hidden_states)
662
+ if not return_dict:
663
+ return (out,)
664
+ return PixNerdTransformer2DModelOutput(sample=out)
665
+
666
+ def predict_noise(
667
+ self,
668
+ sample: torch.Tensor,
669
+ timestep: Any,
670
+ encoder_hidden_states: torch.Tensor,
671
+ use_ema: bool = False,
672
+ ) -> torch.Tensor:
673
+ t = self._as_timestep_tensor(timestep, sample.shape[0], sample.device)
674
+ denoiser = self.get_inference_denoiser(use_ema=use_ema)
675
+ return denoiser(sample, t, encoder_hidden_states)
676
+
677
+ def get_inference_denoiser(self, use_ema: bool = True) -> torch.nn.Module:
678
+ if use_ema and self.ema_denoiser is not None:
679
+ return self.ema_denoiser
680
+ return self.denoiser
681
+
682
+ @torch.no_grad()
683
+ def get_conditioning(
684
+ self,
685
+ y: Iterable[Any],
686
+ metadata: Optional[Dict[str, Any]] = None,
687
+ ):
688
+ metadata = {} if metadata is None else metadata
689
+ return self.conditioner(y, metadata)
690
+
691
+ @torch.no_grad()
692
+ def encode(self, x: torch.Tensor) -> torch.Tensor:
693
+ if self.vae is None:
694
+ return x
695
+ return self.vae.encode(x)
696
+
697
+ @torch.no_grad()
698
+ def decode(self, latents: torch.Tensor) -> torch.Tensor:
699
+ if self.vae is None:
700
+ return latents
701
+ return self.vae.decode(latents)
702
+
703
+ @torch.no_grad()
704
+ def sync_ema(self) -> None:
705
+ if self.ema_denoiser is None:
706
+ return
707
+ self.ema_denoiser.load_state_dict(self.denoiser.state_dict(), strict=True)
708
+ self.ema_denoiser.to(torch.float32)
709
+
710
+ @torch.no_grad()
711
+ def ema_step(self, decay: Optional[float] = None) -> None:
712
+ if self.ema_denoiser is None:
713
+ return
714
+ decay = self.ema_decay if decay is None else float(decay)
715
+ for ema_param, param in zip(self.ema_denoiser.parameters(), self.denoiser.parameters()):
716
+ ema_param.mul_(decay).add_(param.detach().float(), alpha=1.0 - decay)
717
+
718
+ def compute_training_loss(
719
+ self,
720
+ x: torch.Tensor,
721
+ y: Iterable[Any],
722
+ scheduler: torch.nn.Module,
723
+ metadata: Optional[Dict[str, Any]] = None,
724
+ ) -> Dict[str, torch.Tensor]:
725
+ if self.diffusion_trainer is None:
726
+ raise RuntimeError("diffusion_trainer is not configured.")
727
+ metadata = {} if metadata is None else metadata
728
+
729
+ with torch.no_grad():
730
+ x = self.encode(x)
731
+ condition, uncondition = self.get_conditioning(y, metadata)
732
+
733
+ return self.diffusion_trainer(
734
+ self.denoiser,
735
+ self.ema_denoiser if self.ema_denoiser is not None else self.denoiser,
736
+ scheduler,
737
+ x,
738
+ condition,
739
+ uncondition,
740
+ metadata,
741
+ )
742
+
743
+ __all__ = [
744
+ "PixNerDiT",
745
+ "LabelConditioner",
746
+ "PixelAE",
747
+ "PixNerdTransformer2DModel",
748
+ "PixNerdTransformer2DModelOutput",
749
+ ]
PixNerd-XL-16-512/README.md ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # PixNerd-XL-16-512
2
+
3
+ Self-contained PixNerd-XL/16 checkpoint inside [`BiliSakura/PixNerd-diffusers`](https://huggingface.co/BiliSakura/PixNerd-diffusers). Runtime dependencies: this folder + PyPI `diffusers`/`torch` only.
4
+
5
+ ## Hub path
6
+
7
+ `BiliSakura/PixNerd-diffusers/PixNerd-XL-16-512`
8
+
9
+ ## Layout
10
+
11
+ ```text
12
+ PixNerd-XL-16-512/
13
+ ├── pipeline.py
14
+ ├── model_index.json
15
+ ├── conversion_metadata.json
16
+ ├── transformer/
17
+ └── scheduler/
18
+ ```
19
+
20
+ ## Load
21
+
22
+ ```python
23
+ import torch
24
+ from diffusers import DiffusionPipeline
25
+
26
+ pipe = DiffusionPipeline.from_pretrained(
27
+ "BiliSakura/PixNerd-diffusers/PixNerd-XL-16-512",
28
+ trust_remote_code=True,
29
+ torch_dtype=torch.float32,
30
+ ).to("cuda")
31
+
32
+ images = pipe(
33
+ prompt=207,
34
+ height=512,
35
+ width=512,
36
+ num_inference_steps=25,
37
+ guidance_scale=4.0,
38
+ timeshift=3.0,
39
+ order=2,
40
+ ).images
41
+ ```
PixNerd-XL-16-512/model_index.json CHANGED
@@ -1,12 +1,15 @@
1
  {
2
- "_class_name": "PixNerdPipeline",
3
- "_diffusers_version": "0.30.0",
 
 
 
4
  "scheduler": [
5
- "diffusers_modules.local.scheduling_pixnerd_flow_match",
6
  "PixNerdFlowMatchScheduler"
7
  ],
8
  "transformer": [
9
- "diffusers_modules.local.modeling_pixnerd_transformer_2d",
10
  "PixNerdTransformer2DModel"
11
  ]
12
  }
 
1
  {
2
+ "_class_name": [
3
+ "pipeline",
4
+ "PixNerdPipeline"
5
+ ],
6
+ "_diffusers_version": "0.36.0",
7
  "scheduler": [
8
+ "scheduling_pixnerd_flow_match",
9
  "PixNerdFlowMatchScheduler"
10
  ],
11
  "transformer": [
12
+ "modeling_pixnerd_transformer_2d",
13
  "PixNerdTransformer2DModel"
14
  ]
15
  }
PixNerd-XL-16-512/pipeline.py CHANGED
@@ -1,7 +1,9 @@
1
  from __future__ import annotations
2
 
 
3
  from dataclasses import dataclass
4
- from typing import List, Optional, Sequence, Union
 
5
 
6
  import torch
7
  from diffusers import DiffusionPipeline
@@ -9,10 +11,8 @@ from diffusers.image_processor import VaeImageProcessor
9
  from diffusers.utils import BaseOutput
10
  from PIL import Image
11
 
12
- from .modeling_pixnerd_transformer_2d import PixNerdTransformer2DModel
13
- from .scheduling_pixnerd_flow_match import PixNerdFlowMatchScheduler
14
-
15
  ConditioningInput = Union[str, int, Sequence[Union[str, int]]]
 
16
 
17
 
18
  @dataclass
@@ -27,9 +27,11 @@ class PixNerdPipeline(DiffusionPipeline):
27
  def __init__(
28
  self,
29
  transformer,
30
- scheduler: PixNerdFlowMatchScheduler,
31
  vae=None,
32
  conditioner=None,
 
 
33
  ):
34
  super().__init__()
35
  if vae is None:
@@ -46,6 +48,170 @@ class PixNerdPipeline(DiffusionPipeline):
46
  )
47
  self.image_processor = VaeImageProcessor(vae_scale_factor=1)
48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  @staticmethod
50
  def _fp_to_uint8(image: torch.Tensor) -> torch.Tensor:
51
  return torch.clip_((image + 1) * 127.5 + 0.5, 0, 255).to(torch.uint8)
@@ -71,10 +237,11 @@ class PixNerdPipeline(DiffusionPipeline):
71
  num_images_per_prompt: int,
72
  ):
73
  prompts = self._repeat(self._to_list(prompt), num_images_per_prompt)
 
74
  metadata = {"device": self._execution_device}
75
  with torch.no_grad():
76
- cond, uncond = self.conditioner(prompts, metadata)
77
- return cond, uncond, prompts
78
 
79
  def prepare_latents(
80
  self,
@@ -124,9 +291,10 @@ class PixNerdPipeline(DiffusionPipeline):
124
  cond, default_uncond, prompts = self.encode_prompt(prompt, num_images_per_prompt)
125
  if negative_prompt is not None:
126
  negative = self._repeat(self._to_list(negative_prompt), num_images_per_prompt)
 
127
  metadata = {"device": self._execution_device}
128
  with torch.no_grad():
129
- _, uncond = self.conditioner(negative, metadata)
130
  else:
131
  uncond = default_uncond
132
  batch_size = len(prompts)
@@ -178,6 +346,7 @@ class PixNerdPipeline(DiffusionPipeline):
178
  return (output,)
179
  return PixNerdPipelineOutput(images=output)
180
 
 
181
  __all__ = [
182
  "PixNerdPipeline",
183
  "PixNerdPipelineOutput",
 
1
  from __future__ import annotations
2
 
3
+ import sys
4
  from dataclasses import dataclass
5
+ from pathlib import Path
6
+ from typing import List, Literal, Optional, Sequence, Union
7
 
8
  import torch
9
  from diffusers import DiffusionPipeline
 
11
  from diffusers.utils import BaseOutput
12
  from PIL import Image
13
 
 
 
 
14
  ConditioningInput = Union[str, int, Sequence[Union[str, int]]]
15
+ Language = Literal["en", "cn"]
16
 
17
 
18
  @dataclass
 
27
  def __init__(
28
  self,
29
  transformer,
30
+ scheduler,
31
  vae=None,
32
  conditioner=None,
33
+ id2label: Optional[dict[int, str]] = None,
34
+ id2label_cn: Optional[dict[int, str]] = None,
35
  ):
36
  super().__init__()
37
  if vae is None:
 
48
  )
49
  self.image_processor = VaeImageProcessor(vae_scale_factor=1)
50
 
51
+ if id2label is None and id2label_cn is None:
52
+ id2label, id2label_cn = self._load_repo_labels()
53
+ self._id2label = id2label or {}
54
+ self._id2label_cn = id2label_cn or {}
55
+ self.labels = self._build_label2id(self._id2label)
56
+ self.labels_cn = self._build_label2id(self._id2label_cn)
57
+ self._labels_loaded_from_path = bool(self._id2label or self._id2label_cn)
58
+
59
+ def _ensure_labels_loaded(self) -> None:
60
+ if self._labels_loaded_from_path:
61
+ return
62
+
63
+ path = getattr(getattr(self, "config", None), "_name_or_path", None) or getattr(self, "_name_or_path", None)
64
+ if not path:
65
+ return
66
+
67
+ id2label, id2label_cn = self._load_labels_for_path(path)
68
+ if id2label is None and id2label_cn is None:
69
+ self._labels_loaded_from_path = True
70
+ return
71
+
72
+ self._id2label = id2label or {}
73
+ self._id2label_cn = id2label_cn or {}
74
+ self.labels = self._build_label2id(self._id2label)
75
+ self.labels_cn = self._build_label2id(self._id2label_cn)
76
+ self._labels_loaded_from_path = True
77
+
78
+ @staticmethod
79
+ def _resolve_labels_dir(pretrained_model_name_or_path: Union[str, Path]) -> Optional[Path]:
80
+ path = Path(pretrained_model_name_or_path)
81
+ if not path.exists():
82
+ try:
83
+ from huggingface_hub import snapshot_download
84
+
85
+ path = Path(snapshot_download(pretrained_model_name_or_path))
86
+ except Exception:
87
+ return None
88
+
89
+ if (path / "model_index.json").exists():
90
+ labels_dir = path.parent / "labels"
91
+ else:
92
+ labels_dir = path / "labels"
93
+ return labels_dir if labels_dir.is_dir() else None
94
+
95
+ @classmethod
96
+ def _load_labels_for_path(
97
+ cls,
98
+ pretrained_model_name_or_path: Union[str, Path],
99
+ ) -> tuple[Optional[dict[int, str]], Optional[dict[int, str]]]:
100
+ labels_dir = cls._resolve_labels_dir(pretrained_model_name_or_path)
101
+ if labels_dir is None:
102
+ return None, None
103
+
104
+ labels_path = str(labels_dir)
105
+ inserted = False
106
+ if labels_path not in sys.path:
107
+ sys.path.insert(0, labels_path)
108
+ inserted = True
109
+ try:
110
+ from imagenet_labels import load_id2label
111
+
112
+ return (
113
+ load_id2label(labels_dir, lang="en"),
114
+ load_id2label(labels_dir, lang="cn"),
115
+ )
116
+ finally:
117
+ if inserted and labels_path in sys.path:
118
+ sys.path.remove(labels_path)
119
+
120
+ @staticmethod
121
+ def _load_repo_labels() -> tuple[Optional[dict[int, str]], Optional[dict[int, str]]]:
122
+ labels_dir = Path(__file__).resolve().parent.parent / "labels"
123
+ if not labels_dir.is_dir():
124
+ return None, None
125
+
126
+ labels_path = str(labels_dir)
127
+ inserted = False
128
+ if labels_path not in sys.path:
129
+ sys.path.insert(0, labels_path)
130
+ inserted = True
131
+ try:
132
+ from imagenet_labels import load_id2label
133
+
134
+ return (
135
+ load_id2label(labels_dir, lang="en"),
136
+ load_id2label(labels_dir, lang="cn"),
137
+ )
138
+ finally:
139
+ if inserted and labels_path in sys.path:
140
+ sys.path.remove(labels_path)
141
+
142
+ @classmethod
143
+ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
144
+ pipe = super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
145
+ id2label, id2label_cn = cls._load_labels_for_path(pretrained_model_name_or_path)
146
+ if id2label is not None or id2label_cn is not None:
147
+ pipe._id2label = id2label or {}
148
+ pipe._id2label_cn = id2label_cn or {}
149
+ pipe.labels = cls._build_label2id(pipe._id2label)
150
+ pipe.labels_cn = cls._build_label2id(pipe._id2label_cn)
151
+ return pipe
152
+
153
+ @staticmethod
154
+ def _build_label2id(id2label: dict[int, str]) -> dict[str, int]:
155
+ label2id: dict[str, int] = {}
156
+ for class_id, value in id2label.items():
157
+ for synonym in value.split(","):
158
+ synonym = synonym.strip()
159
+ if synonym:
160
+ label2id[synonym] = int(class_id)
161
+ return dict(sorted(label2id.items()))
162
+
163
+ @property
164
+ def id2label(self) -> dict[int, str]:
165
+ self._ensure_labels_loaded()
166
+ return self._id2label
167
+
168
+ @property
169
+ def id2label_cn(self) -> dict[int, str]:
170
+ self._ensure_labels_loaded()
171
+ return self._id2label_cn
172
+
173
+ def get_label_ids(
174
+ self,
175
+ labels: Union[str, List[str]],
176
+ *,
177
+ lang: Language = "en",
178
+ ) -> List[int]:
179
+ self._ensure_labels_loaded()
180
+ if isinstance(labels, str):
181
+ labels = [labels]
182
+
183
+ label2id = self.labels if lang == "en" else self.labels_cn
184
+ if not label2id:
185
+ raise ValueError(
186
+ f"No {lang} labels loaded. Ensure `labels/id2label_{lang}.json` exists next to the variant folder."
187
+ )
188
+
189
+ missing = [label for label in labels if label not in label2id]
190
+ if missing:
191
+ preview = ", ".join(list(label2id.keys())[:8])
192
+ raise ValueError(
193
+ f"Unknown label(s) for lang={lang!r}: {missing}. Example valid labels: {preview}, ..."
194
+ )
195
+ return [label2id[label] for label in labels]
196
+
197
+ def _resolve_prompt_item(self, value: Union[str, int]) -> int:
198
+ if isinstance(value, int):
199
+ return value
200
+ if value.isdigit():
201
+ return int(value)
202
+ if value in self.labels:
203
+ return self.labels[value]
204
+ if value in self.labels_cn:
205
+ return self.labels_cn[value]
206
+ raise ValueError(
207
+ f"Unknown class label {value!r}. Pass an ImageNet class id or a synonym from "
208
+ "`pipe.labels` / `pipe.labels_cn`."
209
+ )
210
+
211
+ def _resolve_prompts(self, prompts: List[Union[str, int]]) -> List[int]:
212
+ self._ensure_labels_loaded()
213
+ return [self._resolve_prompt_item(prompt) for prompt in prompts]
214
+
215
  @staticmethod
216
  def _fp_to_uint8(image: torch.Tensor) -> torch.Tensor:
217
  return torch.clip_((image + 1) * 127.5 + 0.5, 0, 255).to(torch.uint8)
 
237
  num_images_per_prompt: int,
238
  ):
239
  prompts = self._repeat(self._to_list(prompt), num_images_per_prompt)
240
+ resolved = self._resolve_prompts(prompts)
241
  metadata = {"device": self._execution_device}
242
  with torch.no_grad():
243
+ cond, uncond = self.conditioner(resolved, metadata)
244
+ return cond, uncond, resolved
245
 
246
  def prepare_latents(
247
  self,
 
291
  cond, default_uncond, prompts = self.encode_prompt(prompt, num_images_per_prompt)
292
  if negative_prompt is not None:
293
  negative = self._repeat(self._to_list(negative_prompt), num_images_per_prompt)
294
+ resolved_negative = self._resolve_prompts(negative)
295
  metadata = {"device": self._execution_device}
296
  with torch.no_grad():
297
+ _, uncond = self.conditioner(resolved_negative, metadata)
298
  else:
299
  uncond = default_uncond
300
  batch_size = len(prompts)
 
346
  return (output,)
347
  return PixNerdPipelineOutput(images=output)
348
 
349
+
350
  __all__ = [
351
  "PixNerdPipeline",
352
  "PixNerdPipelineOutput",
PixNerd-XL-16-512/scheduler/scheduling_pixnerd_flow_match.py ADDED
@@ -0,0 +1,231 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Any, Dict, List, Optional, Tuple, Union
5
+
6
+ import torch
7
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
8
+ from diffusers.schedulers.scheduling_utils import SchedulerMixin
9
+ from diffusers.utils import BaseOutput
10
+
11
+ @dataclass
12
+ class PixNerdSchedulerOutput(BaseOutput):
13
+ prev_sample: torch.Tensor
14
+
15
+
16
+ class PixNerdFlowMatchScheduler(SchedulerMixin, ConfigMixin):
17
+ """
18
+ Diffusers-compatible scheduler wrapper for PixNerd's AdamLM flow-matching sampler.
19
+ """
20
+
21
+ config_name = "scheduler_config.json"
22
+ order = 1
23
+ init_noise_sigma = 1.0
24
+
25
+ @staticmethod
26
+ def _lagrange_coeffs(order: int, pre_ts: torch.Tensor, t_start: torch.Tensor, t_end: torch.Tensor) -> List[float]:
27
+ ts = [float(v) for v in pre_ts[-order:].tolist()]
28
+ a = float(t_start)
29
+ b = float(t_end)
30
+
31
+ if order == 1:
32
+ return [1.0]
33
+ if order == 2:
34
+ t1, t2 = ts
35
+ int1 = 0.5 / (t1 - t2) * ((b - t2) ** 2 - (a - t2) ** 2)
36
+ int2 = 0.5 / (t2 - t1) * ((b - t1) ** 2 - (a - t1) ** 2)
37
+ total = int1 + int2
38
+ return [int1 / total, int2 / total]
39
+ if order == 3:
40
+ t1, t2, t3 = ts
41
+ int1_denom = (t1 - t2) * (t1 - t3)
42
+ int1 = ((1 / 3) * b**3 - 0.5 * (t2 + t3) * b**2 + (t2 * t3) * b) - (
43
+ (1 / 3) * a**3 - 0.5 * (t2 + t3) * a**2 + (t2 * t3) * a
44
+ )
45
+ int1 = int1 / int1_denom
46
+ int2_denom = (t2 - t1) * (t2 - t3)
47
+ int2 = ((1 / 3) * b**3 - 0.5 * (t1 + t3) * b**2 + (t1 * t3) * b) - (
48
+ (1 / 3) * a**3 - 0.5 * (t1 + t3) * a**2 + (t1 * t3) * a
49
+ )
50
+ int2 = int2 / int2_denom
51
+ int3_denom = (t3 - t1) * (t3 - t2)
52
+ int3 = ((1 / 3) * b**3 - 0.5 * (t1 + t2) * b**2 + (t1 * t2) * b) - (
53
+ (1 / 3) * a**3 - 0.5 * (t1 + t2) * a**2 + (t1 * t2) * a
54
+ )
55
+ int3 = int3 / int3_denom
56
+ total = int1 + int2 + int3
57
+ return [int1 / total, int2 / total, int3 / total]
58
+ if order == 4:
59
+ t1, t2, t3, t4 = ts
60
+ int1_denom = (t1 - t2) * (t1 - t3) * (t1 - t4)
61
+ int1 = ((1 / 4) * b**4 - (1 / 3) * (t2 + t3 + t4) * b**3 + 0.5 * (t3 * t4 + t2 * t3 + t2 * t4) * b**2 - (t2 * t3 * t4) * b) - (
62
+ (1 / 4) * a**4 - (1 / 3) * (t2 + t3 + t4) * a**3 + 0.5 * (t3 * t4 + t2 * t3 + t2 * t4) * a**2 - (t2 * t3 * t4) * a
63
+ )
64
+ int1 = int1 / int1_denom
65
+ int2_denom = (t2 - t1) * (t2 - t3) * (t2 - t4)
66
+ int2 = ((1 / 4) * b**4 - (1 / 3) * (t1 + t3 + t4) * b**3 + 0.5 * (t3 * t4 + t1 * t3 + t1 * t4) * b**2 - (t1 * t3 * t4) * b) - (
67
+ (1 / 4) * a**4 - (1 / 3) * (t1 + t3 + t4) * a**3 + 0.5 * (t3 * t4 + t1 * t3 + t1 * t4) * a**2 - (t1 * t3 * t4) * a
68
+ )
69
+ int2 = int2 / int2_denom
70
+ int3_denom = (t3 - t1) * (t3 - t2) * (t3 - t4)
71
+ int3 = ((1 / 4) * b**4 - (1 / 3) * (t1 + t2 + t4) * b**3 + 0.5 * (t4 * t2 + t1 * t2 + t1 * t4) * b**2 - (t1 * t2 * t4) * b) - (
72
+ (1 / 4) * a**4 - (1 / 3) * (t1 + t2 + t4) * a**3 + 0.5 * (t4 * t2 + t1 * t2 + t1 * t4) * a**2 - (t1 * t2 * t4) * a
73
+ )
74
+ int3 = int3 / int3_denom
75
+ int4_denom = (t4 - t1) * (t4 - t2) * (t4 - t3)
76
+ int4 = ((1 / 4) * b**4 - (1 / 3) * (t1 + t2 + t3) * b**3 + 0.5 * (t3 * t2 + t1 * t2 + t1 * t3) * b**2 - (t1 * t2 * t3) * b) - (
77
+ (1 / 4) * a**4 - (1 / 3) * (t1 + t2 + t3) * a**3 + 0.5 * (t3 * t2 + t1 * t2 + t1 * t3) * a**2 - (t1 * t2 * t3) * a
78
+ )
79
+ int4 = int4 / int4_denom
80
+ total = int1 + int2 + int3 + int4
81
+ return [int1 / total, int2 / total, int3 / total, int4 / total]
82
+ raise ValueError(f"Unsupported solver order: {order}.")
83
+
84
+ @register_to_config
85
+ def __init__(
86
+ self,
87
+ num_train_timesteps: int = 1000,
88
+ num_inference_steps: int = 25,
89
+ guidance_scale: float = 4.0,
90
+ timeshift: float = 3.0,
91
+ order: int = 2,
92
+ guidance_interval_min: float = 0.0,
93
+ guidance_interval_max: float = 1.0,
94
+ last_step: Optional[float] = None,
95
+ ) -> None:
96
+ self.num_inference_steps = int(num_inference_steps)
97
+ self.guidance_scale = float(guidance_scale)
98
+ self.timeshift = float(timeshift)
99
+ self.order = int(order)
100
+ self.guidance_interval_min = float(guidance_interval_min)
101
+ self.guidance_interval_max = float(guidance_interval_max)
102
+ self.last_step = last_step
103
+ self._reset_state()
104
+
105
+ @classmethod
106
+ def from_sampler_spec(cls, sampler_spec: Dict[str, Any]) -> "PixNerdFlowMatchScheduler":
107
+ init_args = dict(sampler_spec.get("init_args", {}))
108
+ return cls(
109
+ num_inference_steps=int(init_args.get("num_steps", 25)),
110
+ guidance_scale=float(init_args.get("guidance", 4.0)),
111
+ timeshift=float(init_args.get("timeshift", 3.0)),
112
+ order=int(init_args.get("order", 2)),
113
+ guidance_interval_min=float(init_args.get("guidance_interval_min", 0.0)),
114
+ guidance_interval_max=float(init_args.get("guidance_interval_max", 1.0)),
115
+ last_step=init_args.get("last_step"),
116
+ )
117
+
118
+ def _reset_state(self) -> None:
119
+ self.timesteps: Optional[torch.Tensor] = None
120
+ self._timedeltas: Optional[torch.Tensor] = None
121
+ self._solver_coeffs = None
122
+ self._model_outputs = []
123
+ self._step_index = 0
124
+
125
+ @staticmethod
126
+ def _shift_respace_fn(t: torch.Tensor, shift: float = 3.0) -> torch.Tensor:
127
+ return t / (t + (1 - t) * shift)
128
+
129
+ def _build_solver_state(
130
+ self,
131
+ num_inference_steps: int,
132
+ timeshift: float,
133
+ device: Optional[Union[str, torch.device]] = None,
134
+ ) -> Tuple[torch.Tensor, torch.Tensor, List[List[float]]]:
135
+ last_step = self.last_step
136
+ if last_step is None:
137
+ last_step = 1.0 / float(num_inference_steps)
138
+
139
+ endpoints = torch.linspace(0.0, 1 - float(last_step), int(num_inference_steps), dtype=torch.float32)
140
+ endpoints = torch.cat([endpoints, torch.tensor([1.0], dtype=torch.float32)], dim=0)
141
+ timesteps = self._shift_respace_fn(endpoints, timeshift).to(device=device)
142
+ timedeltas = (timesteps[1:] - timesteps[:-1]).to(device=device)
143
+
144
+ solver_coeffs: List[List[float]] = [[] for _ in range(int(num_inference_steps))]
145
+ for i in range(int(num_inference_steps)):
146
+ order = min(self.order, i + 1)
147
+ pre_ts = timesteps[: i + 1]
148
+ coeffs = self._lagrange_coeffs(order, pre_ts, pre_ts[i], timesteps[i + 1])
149
+ solver_coeffs[i] = coeffs
150
+ return timesteps[:-1], timedeltas, solver_coeffs
151
+
152
+ def set_timesteps(
153
+ self,
154
+ num_inference_steps: Optional[int] = None,
155
+ device: Optional[Union[str, torch.device]] = None,
156
+ timeshift: Optional[float] = None,
157
+ guidance_scale: Optional[float] = None,
158
+ order: Optional[int] = None,
159
+ **kwargs: Any,
160
+ ) -> None:
161
+ if num_inference_steps is not None:
162
+ self.num_inference_steps = int(num_inference_steps)
163
+ if timeshift is not None:
164
+ self.timeshift = float(timeshift)
165
+ if guidance_scale is not None:
166
+ self.guidance_scale = float(guidance_scale)
167
+ if order is not None:
168
+ self.order = int(order)
169
+
170
+ timesteps, timedeltas, solver_coeffs = self._build_solver_state(
171
+ self.num_inference_steps,
172
+ self.timeshift,
173
+ device=device,
174
+ )
175
+ self.timesteps = timesteps
176
+ self._timedeltas = timedeltas
177
+ self._solver_coeffs = solver_coeffs
178
+ self._model_outputs = []
179
+ self._step_index = 0
180
+
181
+ def scale_model_input(self, sample: torch.Tensor, timestep: Optional[torch.Tensor] = None) -> torch.Tensor:
182
+ return sample
183
+
184
+ def classifier_free_guidance(self, model_output: torch.Tensor) -> torch.Tensor:
185
+ if model_output.shape[0] % 2 != 0:
186
+ raise ValueError("Classifier-free guidance expects concatenated unconditional/conditional batches.")
187
+ uncond, cond = model_output.chunk(2, dim=0)
188
+ return uncond + self.guidance_scale * (cond - uncond)
189
+
190
+ def step(
191
+ self,
192
+ model_output: torch.Tensor,
193
+ timestep: Union[torch.Tensor, float, int],
194
+ sample: torch.Tensor,
195
+ return_dict: bool = True,
196
+ **kwargs: Any,
197
+ ) -> Union[PixNerdSchedulerOutput, Tuple[torch.Tensor]]:
198
+ if self.timesteps is None or self._timedeltas is None or self._solver_coeffs is None:
199
+ raise RuntimeError("`set_timesteps` must be called before `step`.")
200
+ if self._step_index >= len(self._solver_coeffs):
201
+ raise RuntimeError("Scheduler step index exceeded configured timesteps.")
202
+
203
+ coeffs = self._solver_coeffs[self._step_index]
204
+ self._model_outputs.append(model_output)
205
+ order = len(coeffs)
206
+ pred = torch.zeros_like(model_output)
207
+ recent = self._model_outputs[-order:]
208
+ for coeff, output in zip(coeffs, recent):
209
+ pred = pred + coeff * output
210
+
211
+ prev_sample = sample + pred * self._timedeltas[self._step_index]
212
+ self._step_index += 1
213
+
214
+ if not return_dict:
215
+ return (prev_sample,)
216
+ return PixNerdSchedulerOutput(prev_sample=prev_sample)
217
+
218
+ def add_noise(
219
+ self,
220
+ original_samples: torch.Tensor,
221
+ noise: torch.Tensor,
222
+ timesteps: torch.Tensor,
223
+ ) -> torch.Tensor:
224
+ alpha = timesteps.view(-1, 1, 1, 1)
225
+ sigma = (1.0 - timesteps).view(-1, 1, 1, 1)
226
+ return alpha * original_samples + sigma * noise
227
+
228
+ __all__ = [
229
+ "PixNerdFlowMatchScheduler",
230
+ "PixNerdSchedulerOutput",
231
+ ]
PixNerd-XL-16-512/transformer/config.json CHANGED
@@ -3,13 +3,13 @@
3
  "_diffusers_version": "0.36.0",
4
  "compile_denoiser": false,
5
  "conditioner_spec": {
6
- "class_path": "diffusers_modules.local.modeling_pixnerd_transformer_2d.LabelConditioner",
7
  "init_args": {
8
  "num_classes": 1000
9
  }
10
  },
11
  "denoiser_spec": {
12
- "class_path": "diffusers_modules.local.modeling_pixnerd_transformer_2d.PixNerDiT",
13
  "init_args": {
14
  "hidden_size": 1152,
15
  "hidden_size_x": 64,
@@ -26,7 +26,7 @@
26
  "ema_decay": 0.9999,
27
  "use_ema": true,
28
  "vae_spec": {
29
- "class_path": "diffusers_modules.local.modeling_pixnerd_transformer_2d.PixelAE",
30
  "init_args": {
31
  "scale": 1.0
32
  }
 
3
  "_diffusers_version": "0.36.0",
4
  "compile_denoiser": false,
5
  "conditioner_spec": {
6
+ "class_path": "modeling_pixnerd_transformer_2d.LabelConditioner",
7
  "init_args": {
8
  "num_classes": 1000
9
  }
10
  },
11
  "denoiser_spec": {
12
+ "class_path": "modeling_pixnerd_transformer_2d.PixNerDiT",
13
  "init_args": {
14
  "hidden_size": 1152,
15
  "hidden_size_x": 64,
 
26
  "ema_decay": 0.9999,
27
  "use_ema": true,
28
  "vae_spec": {
29
+ "class_path": "modeling_pixnerd_transformer_2d.PixelAE",
30
  "init_args": {
31
  "scale": 1.0
32
  }
PixNerd-XL-16-512/transformer/modeling_pixnerd_transformer_2d.py ADDED
@@ -0,0 +1,749 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import copy
4
+ import importlib
5
+ import math
6
+ import sys
7
+ from dataclasses import dataclass
8
+ from functools import lru_cache
9
+ from typing import Any, Dict, Iterable, List, Optional, Tuple
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
14
+ from diffusers.models.modeling_utils import ModelMixin
15
+ from diffusers.utils import BaseOutput
16
+ from torch.nn.functional import scaled_dot_product_attention
17
+
18
+ class BaseAE(torch.nn.Module):
19
+ def __init__(self, scale=1.0, shift=0.0):
20
+ super().__init__()
21
+ self.scale = scale
22
+ self.shift = shift
23
+
24
+ def encode(self, x):
25
+ return self._impl_encode(x) #.to(torch.bfloat16)
26
+
27
+ # @torch.autocast("cuda", dtype=torch.bfloat16)
28
+ def decode(self, x):
29
+ return self._impl_decode(x) #.to(torch.bfloat16)
30
+
31
+ def _impl_encode(self, x):
32
+ raise NotImplementedError
33
+
34
+ def _impl_decode(self, x):
35
+ raise NotImplementedError
36
+
37
+ def uint82fp(x):
38
+ x = x.to(torch.float32)
39
+ x = (x - 127.5) / 127.5
40
+ return x
41
+
42
+ def fp2uint8(x):
43
+ x = torch.clip_((x + 1) * 127.5 + 0.5, 0, 255).to(torch.uint8)
44
+ return x
45
+
46
+
47
+ class PixelAE(BaseAE):
48
+ def __init__(self, scale=1.0, shift=0.0):
49
+ super().__init__(scale, shift)
50
+
51
+ def _impl_encode(self, x):
52
+ return x/self.scale+self.shift
53
+
54
+ def _impl_decode(self, x):
55
+ return (x-self.shift)*self.scale
56
+
57
+
58
+ def resolve_conditioner_device(metadata: dict, fallback: torch.device | None = None) -> torch.device:
59
+ if metadata is None:
60
+ metadata = {}
61
+ if "device" in metadata and metadata["device"] is not None:
62
+ return torch.device(metadata["device"])
63
+ if fallback is not None:
64
+ return fallback
65
+ return torch.device("cuda" if torch.cuda.is_available() else "cpu")
66
+
67
+
68
+ class BaseConditioner(nn.Module):
69
+ def __init__(self):
70
+ super(BaseConditioner, self).__init__()
71
+
72
+ def _impl_condition(self, y, metadata)->torch.Tensor:
73
+ raise NotImplementedError()
74
+
75
+ def _impl_uncondition(self, y, metadata)->torch.Tensor:
76
+ raise NotImplementedError()
77
+
78
+ @torch.no_grad()
79
+ def __call__(self, y, metadata:dict={}):
80
+ condition = self._impl_condition(y, metadata)
81
+ uncondition = self._impl_uncondition(y, metadata)
82
+ if condition.dtype in [torch.float64, torch.float32, torch.float16]:
83
+ condition = condition.to(torch.bfloat16)
84
+ if uncondition.dtype in [torch.float64,torch.float32, torch.float16]:
85
+ uncondition = uncondition.to(torch.bfloat16)
86
+ return condition, uncondition
87
+
88
+
89
+ class ComposeConditioner(BaseConditioner):
90
+ def __init__(self, conditioners:List[BaseConditioner]):
91
+ super().__init__()
92
+ self.conditioners = conditioners
93
+
94
+ def _impl_condition(self, y, metadata):
95
+ condition = []
96
+ for conditioner in self.conditioners:
97
+ condition.append(conditioner._impl_condition(y, metadata))
98
+ condition = torch.cat(condition, dim=1)
99
+ return condition
100
+
101
+ def _impl_uncondition(self, y, metadata):
102
+ uncondition = []
103
+ for conditioner in self.conditioners:
104
+ uncondition.append(conditioner._impl_uncondition(y, metadata))
105
+ uncondition = torch.cat(uncondition, dim=1)
106
+ return uncondition
107
+
108
+
109
+ class LabelConditioner(BaseConditioner):
110
+ def __init__(self, num_classes):
111
+ super().__init__()
112
+ self.null_condition = num_classes
113
+
114
+ def _impl_condition(self, y, metadata):
115
+ device = resolve_conditioner_device(metadata)
116
+ return torch.tensor(y, device=device).long()
117
+
118
+ def _impl_uncondition(self, y, metadata):
119
+ device = resolve_conditioner_device(metadata)
120
+ return torch.full((len(y),), self.null_condition, dtype=torch.long, device=device)
121
+
122
+
123
+ def modulate(x, shift, scale):
124
+ return x * (1 + scale) + shift
125
+
126
+ class Embed(nn.Module):
127
+ def __init__(
128
+ self,
129
+ in_chans: int = 3,
130
+ embed_dim: int = 768,
131
+ norm_layer = None,
132
+ bias: bool = True,
133
+ ):
134
+ super().__init__()
135
+ self.in_chans = in_chans
136
+ self.embed_dim = embed_dim
137
+ self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
138
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
139
+ def forward(self, x):
140
+ x = self.proj(x)
141
+ x = self.norm(x)
142
+ return x
143
+
144
+ class TimestepEmbedder(nn.Module):
145
+
146
+ def __init__(self, hidden_size, frequency_embedding_size=256):
147
+ super().__init__()
148
+ self.mlp = nn.Sequential(
149
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
150
+ nn.SiLU(),
151
+ nn.Linear(hidden_size, hidden_size, bias=True),
152
+ )
153
+ self.frequency_embedding_size = frequency_embedding_size
154
+
155
+ @staticmethod
156
+ def timestep_embedding(t, dim, max_period=10):
157
+ half = dim // 2
158
+ freqs = torch.exp(
159
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
160
+ )
161
+ args = t[..., None].float() * freqs[None, ...]
162
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
163
+ if dim % 2:
164
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
165
+ return embedding
166
+
167
+ def forward(self, t):
168
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
169
+ t_emb = self.mlp(t_freq)
170
+ return t_emb
171
+
172
+ class LabelEmbedder(nn.Module):
173
+ def __init__(self, num_classes, hidden_size):
174
+ super().__init__()
175
+ self.embedding_table = nn.Embedding(num_classes, hidden_size)
176
+ self.num_classes = num_classes
177
+
178
+ def forward(self, labels,):
179
+ embeddings = self.embedding_table(labels)
180
+ return embeddings
181
+
182
+ class FinalLayer(nn.Module):
183
+ def __init__(self, hidden_size, out_channels):
184
+ super().__init__()
185
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
186
+ self.linear = nn.Linear(hidden_size, out_channels, bias=True)
187
+ self.adaLN_modulation = nn.Sequential(
188
+ nn.Linear(hidden_size, 2*hidden_size, bias=True)
189
+ )
190
+
191
+ def forward(self, x, c):
192
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
193
+ x = modulate(self.norm_final(x), shift, scale)
194
+ x = self.linear(x)
195
+ return x
196
+
197
+ class RMSNorm(nn.Module):
198
+ def __init__(self, hidden_size, eps=1e-6):
199
+ """
200
+ LlamaRMSNorm is equivalent to T5LayerNorm
201
+ """
202
+ super().__init__()
203
+ self.weight = nn.Parameter(torch.ones(hidden_size))
204
+ self.variance_epsilon = eps
205
+
206
+ def forward(self, hidden_states):
207
+ input_dtype = hidden_states.dtype
208
+ hidden_states = hidden_states.to(torch.float32)
209
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
210
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
211
+ return self.weight * hidden_states.to(input_dtype)
212
+
213
+ class FeedForward(nn.Module):
214
+ def __init__(
215
+ self,
216
+ dim: int,
217
+ hidden_dim: int,
218
+ ):
219
+ super().__init__()
220
+ hidden_dim = int(2 * hidden_dim / 3)
221
+ self.w1 = nn.Linear(dim, hidden_dim, bias=False)
222
+ self.w3 = nn.Linear(dim, hidden_dim, bias=False)
223
+ self.w2 = nn.Linear(hidden_dim, dim, bias=False)
224
+ def forward(self, x):
225
+ x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
226
+ return x
227
+
228
+ def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0):
229
+ # assert H * H == end
230
+ # flat_patch_pos = torch.linspace(-1, 1, end) # N = end
231
+ x_pos = torch.linspace(0, scale, width)
232
+ y_pos = torch.linspace(0, scale, height)
233
+ y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
234
+ y_pos = y_pos.reshape(-1)
235
+ x_pos = x_pos.reshape(-1)
236
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4
237
+ x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4
238
+ y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4
239
+ x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)
240
+ y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)
241
+ freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2
242
+ freqs_cis = freqs_cis.reshape(height*width, -1)
243
+ return freqs_cis
244
+
245
+
246
+ def apply_rotary_emb(
247
+ xq: torch.Tensor,
248
+ xk: torch.Tensor,
249
+ freqs_cis: torch.Tensor,
250
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
251
+ freqs_cis = freqs_cis[None, :, None, :]
252
+ # xq : B N H Hc
253
+ xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2
254
+ xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
255
+ xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc
256
+ xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
257
+ return xq_out.type_as(xq), xk_out.type_as(xk)
258
+
259
+
260
+ class RAttention(nn.Module):
261
+ def __init__(
262
+ self,
263
+ dim: int,
264
+ num_heads: int = 8,
265
+ qkv_bias: bool = False,
266
+ qk_norm: bool = True,
267
+ attn_drop: float = 0.,
268
+ proj_drop: float = 0.,
269
+ norm_layer: nn.Module = RMSNorm,
270
+ ) -> None:
271
+ super().__init__()
272
+ assert dim % num_heads == 0, 'dim should be divisible by num_heads'
273
+
274
+ self.dim = dim
275
+ self.num_heads = num_heads
276
+ self.head_dim = dim // num_heads
277
+ self.scale = self.head_dim ** -0.5
278
+
279
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
280
+ self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
281
+ self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
282
+ self.attn_drop = nn.Dropout(attn_drop)
283
+ self.proj = nn.Linear(dim, dim)
284
+ self.proj_drop = nn.Dropout(proj_drop)
285
+
286
+ def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor:
287
+ B, N, C = x.shape
288
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4)
289
+ q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc
290
+ q = self.q_norm(q)
291
+ k = self.k_norm(k)
292
+ q, k = apply_rotary_emb(q, k, freqs_cis=pos)
293
+ q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc
294
+ k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc
295
+ v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()
296
+
297
+ x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
298
+
299
+ x = x.transpose(1, 2).reshape(B, N, C)
300
+ x = self.proj(x)
301
+ x = self.proj_drop(x)
302
+ return x
303
+
304
+
305
+
306
+ class FlattenDiTBlock(nn.Module):
307
+ def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
308
+ super().__init__()
309
+ self.norm1 = RMSNorm(hidden_size, eps=1e-6)
310
+ self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False)
311
+ self.norm2 = RMSNorm(hidden_size, eps=1e-6)
312
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
313
+ self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
314
+ self.adaLN_modulation = nn.Sequential(
315
+ nn.Linear(hidden_size, 6 * hidden_size, bias=True)
316
+ )
317
+
318
+ def forward(self, x, c, pos, mask=None):
319
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
320
+ x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask)
321
+ x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
322
+ return x
323
+
324
+ class NerfEmbedder(nn.Module):
325
+ def __init__(self, in_channels, hidden_size_input, max_freqs):
326
+ super().__init__()
327
+ self.max_freqs = max_freqs
328
+ self.hidden_size_input = hidden_size_input
329
+ self.embedder = nn.Sequential(
330
+ nn.Linear(in_channels+max_freqs**2, hidden_size_input, bias=True),
331
+ )
332
+
333
+ @lru_cache
334
+ def fetch_pos(self, patch_size, device, dtype):
335
+ pos_x = torch.linspace(0, 1, patch_size, device=device, dtype=dtype)
336
+ pos_y = torch.linspace(0, 1, patch_size, device=device, dtype=dtype)
337
+ pos_y, pos_x = torch.meshgrid(pos_y, pos_x, indexing="ij")
338
+ pos_x = pos_x.reshape(-1, 1, 1)
339
+ pos_y = pos_y.reshape(-1, 1, 1)
340
+
341
+ freqs = torch.linspace(0, self.max_freqs, self.max_freqs, dtype=dtype, device=device)
342
+ freqs_x = freqs[None, :, None]
343
+ freqs_y = freqs[None, None, :]
344
+ coeffs = (1 + freqs_x * freqs_y) ** -1
345
+ dct_x = torch.cos(pos_x * freqs_x * torch.pi)
346
+ dct_y = torch.cos(pos_y * freqs_y * torch.pi)
347
+ dct = (dct_x * dct_y * coeffs).view(1, -1, self.max_freqs ** 2)
348
+ return dct
349
+
350
+
351
+ def forward(self, inputs):
352
+ B, P2, C = inputs.shape
353
+ patch_size = int(P2 ** 0.5)
354
+ device = inputs.device
355
+ dtype = inputs.dtype
356
+ dct = self.fetch_pos(patch_size, device, dtype)
357
+ dct = dct.repeat(B, 1, 1)
358
+ inputs = torch.cat([inputs, dct], dim=-1)
359
+ inputs = self.embedder(inputs)
360
+ return inputs
361
+
362
+
363
+ class NerfBlock(nn.Module):
364
+ def __init__(self, hidden_size_s, hidden_size_x, mlp_ratio=4):
365
+ super().__init__()
366
+ self.param_generator1 = nn.Sequential(
367
+ nn.Linear(hidden_size_s, 2*hidden_size_x**2*mlp_ratio, bias=True),
368
+ )
369
+ self.norm = RMSNorm(hidden_size_x, eps=1e-6)
370
+ self.mlp_ratio = mlp_ratio
371
+ def forward(self, x, s):
372
+ batch_size, num_x, hidden_size_x = x.shape
373
+ mlp_params1 = self.param_generator1(s)
374
+ fc1_param1, fc2_param1 = mlp_params1.chunk(2, dim=-1)
375
+ fc1_param1 = fc1_param1.view(batch_size, hidden_size_x, hidden_size_x*self.mlp_ratio)
376
+ fc2_param1 = fc2_param1.view(batch_size, hidden_size_x*self.mlp_ratio, hidden_size_x)
377
+
378
+ # normalize fc1
379
+ normalized_fc1_param1 = torch.nn.functional.normalize(fc1_param1, dim=-2)
380
+ # normalize fc2
381
+ normalized_fc2_param1 = torch.nn.functional.normalize(fc2_param1, dim=-2)
382
+ # mlp 1
383
+ res_x = x
384
+ x = self.norm(x)
385
+ x = torch.bmm(x, normalized_fc1_param1)
386
+ x = torch.nn.functional.silu(x)
387
+ x = torch.bmm(x, normalized_fc2_param1)
388
+ x = x + res_x
389
+ return x
390
+
391
+ class NerfFinalLayer(nn.Module):
392
+ def __init__(self, hidden_size, out_channels):
393
+ super().__init__()
394
+ self.norm = RMSNorm(hidden_size, eps=1e-6)
395
+ self.linear = nn.Linear(hidden_size, out_channels, bias=True)
396
+ def forward(self, x):
397
+ x = self.norm(x)
398
+ x = self.linear(x)
399
+ return x
400
+
401
+ class PixNerDiT(nn.Module):
402
+ def __init__(
403
+ self,
404
+ in_channels=4,
405
+ num_groups=12,
406
+ hidden_size=1152,
407
+ hidden_size_x=64,
408
+ nerf_mlpratio=4,
409
+ num_blocks=18,
410
+ num_cond_blocks=4,
411
+ patch_size=2,
412
+ num_classes=1000,
413
+ learn_sigma=True,
414
+ deep_supervision=0,
415
+ weight_path=None,
416
+ load_ema=False,
417
+ ):
418
+ super().__init__()
419
+ self.deep_supervision = deep_supervision
420
+ self.learn_sigma = learn_sigma
421
+ self.in_channels = in_channels
422
+ self.out_channels = in_channels
423
+ self.hidden_size = hidden_size
424
+ self.num_groups = num_groups
425
+ self.num_blocks = num_blocks
426
+ self.num_cond_blocks = num_cond_blocks
427
+ self.patch_size = patch_size
428
+ self.x_embedder = NerfEmbedder(in_channels, hidden_size_x, max_freqs=8)
429
+ self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
430
+ self.t_embedder = TimestepEmbedder(hidden_size)
431
+ self.y_embedder = LabelEmbedder(num_classes+1, hidden_size)
432
+
433
+ self.final_layer = NerfFinalLayer(hidden_size_x, self.out_channels)
434
+
435
+ self.weight_path = weight_path
436
+
437
+ self.load_ema = load_ema
438
+ self.blocks = nn.ModuleList([
439
+ FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_cond_blocks)
440
+ ])
441
+ self.blocks.extend([
442
+ NerfBlock(self.hidden_size, hidden_size_x, nerf_mlpratio) for _ in range(self.num_cond_blocks, self.num_blocks)
443
+ ])
444
+ self.initialize_weights()
445
+ self.precompute_pos = dict()
446
+
447
+ def fetch_pos(self, height, width, device):
448
+ if (height, width) in self.precompute_pos:
449
+ return self.precompute_pos[(height, width)].to(device)
450
+ else:
451
+ pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
452
+ self.precompute_pos[(height, width)] = pos
453
+ return pos
454
+
455
+ def initialize_weights(self):
456
+ # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
457
+ w = self.s_embedder.proj.weight.data
458
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
459
+ nn.init.constant_(self.s_embedder.proj.bias, 0)
460
+
461
+ # Initialize label embedding table:
462
+ nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
463
+
464
+ # Initialize timestep embedding MLP:
465
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
466
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
467
+
468
+ # zero init final layer
469
+ nn.init.zeros_(self.final_layer.linear.weight)
470
+ nn.init.zeros_(self.final_layer.linear.bias)
471
+
472
+
473
+ def forward(self, x, t, y, s=None, mask=None):
474
+ B, _, H, W = x.shape
475
+ pos = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device)
476
+ x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
477
+ t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size)
478
+ y = self.y_embedder(y).view(B, 1, self.hidden_size)
479
+ c = nn.functional.silu(t + y)
480
+ if s is None:
481
+ s = self.s_embedder(x)
482
+ for i in range(self.num_cond_blocks):
483
+ s = self.blocks[i](s, c, pos, mask)
484
+ s = nn.functional.silu(t + s)
485
+ batch_size, length, _ = s.shape
486
+ x = x.reshape(batch_size*length, self.in_channels, self.patch_size**2)
487
+ x = x.transpose(1, 2)
488
+ s = s.view(batch_size*length, self.hidden_size)
489
+ x = self.x_embedder(x)
490
+ for i in range(self.num_cond_blocks, self.num_blocks):
491
+ x = self.blocks[i](x, s)
492
+ x = self.final_layer(x)
493
+ x = x.transpose(1, 2)
494
+ x = x.reshape(batch_size, length, -1)
495
+ x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size)
496
+ return x
497
+
498
+
499
+ def to_container(config: Any) -> Any:
500
+ if hasattr(config, "items") and not isinstance(config, dict):
501
+ return {k: to_container(v) for k, v in config.items()}
502
+ if isinstance(config, list):
503
+ return [to_container(v) for v in config]
504
+ return config
505
+
506
+
507
+ def load_symbol(path: str) -> Any:
508
+ module_path, name = path.rsplit(".", 1)
509
+ if module_path in {__name__, "modeling_pixnerd_transformer_2d"}:
510
+ return getattr(sys.modules[__name__], name)
511
+ module = importlib.import_module(module_path)
512
+ return getattr(module, name)
513
+
514
+
515
+ def instantiate_from_spec(spec: Any) -> Any:
516
+ spec = to_container(spec)
517
+ if isinstance(spec, dict) and "class_path" in spec:
518
+ class_or_fn = load_symbol(spec["class_path"])
519
+ init_args = spec.get("init_args", {})
520
+ if isinstance(init_args, dict):
521
+ init_args = {k: instantiate_from_spec(v) for k, v in init_args.items()}
522
+ return class_or_fn(**init_args)
523
+ if isinstance(spec, dict):
524
+ return {k: instantiate_from_spec(v) for k, v in spec.items()}
525
+ if isinstance(spec, list):
526
+ return [instantiate_from_spec(v) for v in spec]
527
+ if isinstance(spec, str) and "." in spec:
528
+ try:
529
+ return load_symbol(spec)
530
+ except Exception:
531
+ return spec
532
+ return spec
533
+
534
+
535
+ def clone_spec(spec: Dict[str, Any]) -> Dict[str, Any]:
536
+ return copy.deepcopy(to_container(spec))
537
+
538
+
539
+ def load_prefixed_state_dict(
540
+ module: Optional[torch.nn.Module],
541
+ state_dict: Dict[str, torch.Tensor],
542
+ prefixes: Iterable[str],
543
+ ) -> bool:
544
+ if module is None:
545
+ return False
546
+ for prefix in prefixes:
547
+ subset = {
548
+ key[len(prefix) :]: value
549
+ for key, value in state_dict.items()
550
+ if key.startswith(prefix)
551
+ }
552
+ if subset:
553
+ module.load_state_dict(subset, strict=False)
554
+ return True
555
+ return False
556
+
557
+
558
+ @dataclass
559
+ class PixNerdTransformer2DModelOutput(BaseOutput):
560
+ sample: torch.FloatTensor
561
+
562
+
563
+ class PixNerdTransformer2DModel(ModelMixin, ConfigMixin):
564
+ config_name = "config.json"
565
+
566
+ @register_to_config
567
+ def __init__(
568
+ self,
569
+ denoiser_spec: Dict[str, Any],
570
+ conditioner_spec: Dict[str, Any],
571
+ vae_spec: Optional[Dict[str, Any]] = None,
572
+ diffusion_trainer_spec: Optional[Dict[str, Any]] = None,
573
+ use_ema: bool = True,
574
+ ema_decay: float = 0.9999,
575
+ compile_denoiser: bool = False,
576
+ ) -> None:
577
+ super().__init__()
578
+ self.denoiser = instantiate_from_spec(to_container(denoiser_spec))
579
+ self.conditioner = instantiate_from_spec(to_container(conditioner_spec))
580
+ self.vae = instantiate_from_spec(to_container(vae_spec)) if vae_spec is not None else None
581
+ self.diffusion_trainer = (
582
+ instantiate_from_spec(to_container(diffusion_trainer_spec))
583
+ if diffusion_trainer_spec is not None
584
+ else None
585
+ )
586
+
587
+ self.use_ema = bool(use_ema)
588
+ self.ema_decay = float(ema_decay)
589
+ self.ema_denoiser = copy.deepcopy(self.denoiser) if self.use_ema else None
590
+ if self.ema_denoiser is not None:
591
+ self.ema_denoiser.to(torch.float32)
592
+
593
+ if compile_denoiser and hasattr(self.denoiser, "compile"):
594
+ self.denoiser.compile()
595
+ if self.ema_denoiser is not None:
596
+ self.ema_denoiser.compile()
597
+
598
+ self._freeze_non_trainable_modules()
599
+ if self.ema_denoiser is not None:
600
+ self.sync_ema()
601
+
602
+ @property
603
+ def patch_size(self) -> int:
604
+ return int(getattr(self.denoiser, "patch_size", 1))
605
+
606
+ @property
607
+ def in_channels(self) -> int:
608
+ return int(getattr(self.denoiser, "in_channels", 3))
609
+
610
+ @classmethod
611
+ def from_project_config(
612
+ cls,
613
+ model_config: Dict[str, Any],
614
+ use_ema: bool = True,
615
+ compile_denoiser: bool = False,
616
+ ) -> "PixNerdTransformer2DModel":
617
+ model_config = to_container(model_config)
618
+ ema_decay = model_config.get("ema_tracker", {}).get("init_args", {}).get("decay", 0.9999)
619
+ return cls(
620
+ denoiser_spec=model_config["denoiser"],
621
+ conditioner_spec=model_config["conditioner"],
622
+ vae_spec=model_config.get("vae"),
623
+ diffusion_trainer_spec=model_config.get("diffusion_trainer"),
624
+ use_ema=use_ema,
625
+ ema_decay=ema_decay,
626
+ compile_denoiser=compile_denoiser,
627
+ )
628
+
629
+ @staticmethod
630
+ def _as_timestep_tensor(
631
+ timestep: Any,
632
+ batch_size: int,
633
+ device: torch.device,
634
+ ) -> torch.Tensor:
635
+ if isinstance(timestep, torch.Tensor):
636
+ if timestep.ndim == 0:
637
+ return timestep.repeat(batch_size).to(device=device, dtype=torch.float32)
638
+ return timestep.to(device=device, dtype=torch.float32)
639
+ return torch.full((batch_size,), float(timestep), device=device, dtype=torch.float32)
640
+
641
+ def _freeze_module(self, module: Optional[torch.nn.Module]) -> None:
642
+ if module is None:
643
+ return
644
+ module.eval()
645
+ for parameter in module.parameters():
646
+ parameter.requires_grad = False
647
+
648
+ def _freeze_non_trainable_modules(self) -> None:
649
+ self._freeze_module(self.conditioner)
650
+ self._freeze_module(self.vae)
651
+ self._freeze_module(self.ema_denoiser)
652
+
653
+ def forward(
654
+ self,
655
+ sample: torch.Tensor,
656
+ timestep: Any,
657
+ encoder_hidden_states: torch.Tensor,
658
+ return_dict: bool = True,
659
+ ) -> PixNerdTransformer2DModelOutput | Tuple[torch.Tensor]:
660
+ t = self._as_timestep_tensor(timestep, sample.shape[0], sample.device)
661
+ out = self.denoiser(sample, t, encoder_hidden_states)
662
+ if not return_dict:
663
+ return (out,)
664
+ return PixNerdTransformer2DModelOutput(sample=out)
665
+
666
+ def predict_noise(
667
+ self,
668
+ sample: torch.Tensor,
669
+ timestep: Any,
670
+ encoder_hidden_states: torch.Tensor,
671
+ use_ema: bool = False,
672
+ ) -> torch.Tensor:
673
+ t = self._as_timestep_tensor(timestep, sample.shape[0], sample.device)
674
+ denoiser = self.get_inference_denoiser(use_ema=use_ema)
675
+ return denoiser(sample, t, encoder_hidden_states)
676
+
677
+ def get_inference_denoiser(self, use_ema: bool = True) -> torch.nn.Module:
678
+ if use_ema and self.ema_denoiser is not None:
679
+ return self.ema_denoiser
680
+ return self.denoiser
681
+
682
+ @torch.no_grad()
683
+ def get_conditioning(
684
+ self,
685
+ y: Iterable[Any],
686
+ metadata: Optional[Dict[str, Any]] = None,
687
+ ):
688
+ metadata = {} if metadata is None else metadata
689
+ return self.conditioner(y, metadata)
690
+
691
+ @torch.no_grad()
692
+ def encode(self, x: torch.Tensor) -> torch.Tensor:
693
+ if self.vae is None:
694
+ return x
695
+ return self.vae.encode(x)
696
+
697
+ @torch.no_grad()
698
+ def decode(self, latents: torch.Tensor) -> torch.Tensor:
699
+ if self.vae is None:
700
+ return latents
701
+ return self.vae.decode(latents)
702
+
703
+ @torch.no_grad()
704
+ def sync_ema(self) -> None:
705
+ if self.ema_denoiser is None:
706
+ return
707
+ self.ema_denoiser.load_state_dict(self.denoiser.state_dict(), strict=True)
708
+ self.ema_denoiser.to(torch.float32)
709
+
710
+ @torch.no_grad()
711
+ def ema_step(self, decay: Optional[float] = None) -> None:
712
+ if self.ema_denoiser is None:
713
+ return
714
+ decay = self.ema_decay if decay is None else float(decay)
715
+ for ema_param, param in zip(self.ema_denoiser.parameters(), self.denoiser.parameters()):
716
+ ema_param.mul_(decay).add_(param.detach().float(), alpha=1.0 - decay)
717
+
718
+ def compute_training_loss(
719
+ self,
720
+ x: torch.Tensor,
721
+ y: Iterable[Any],
722
+ scheduler: torch.nn.Module,
723
+ metadata: Optional[Dict[str, Any]] = None,
724
+ ) -> Dict[str, torch.Tensor]:
725
+ if self.diffusion_trainer is None:
726
+ raise RuntimeError("diffusion_trainer is not configured.")
727
+ metadata = {} if metadata is None else metadata
728
+
729
+ with torch.no_grad():
730
+ x = self.encode(x)
731
+ condition, uncondition = self.get_conditioning(y, metadata)
732
+
733
+ return self.diffusion_trainer(
734
+ self.denoiser,
735
+ self.ema_denoiser if self.ema_denoiser is not None else self.denoiser,
736
+ scheduler,
737
+ x,
738
+ condition,
739
+ uncondition,
740
+ metadata,
741
+ )
742
+
743
+ __all__ = [
744
+ "PixNerDiT",
745
+ "LabelConditioner",
746
+ "PixelAE",
747
+ "PixNerdTransformer2DModel",
748
+ "PixNerdTransformer2DModelOutput",
749
+ ]
README.md CHANGED
@@ -11,71 +11,129 @@ language:
11
  - en
12
  ---
13
 
14
- # PixNerd-XL-16 Diffusers Checkpoints
15
 
16
- Production-ready Diffusers export of PixNerd-XL/16 class-conditional ImageNet checkpoints.
17
 
18
- ## Available Checkpoints
19
 
20
- - `PixNerd-XL-16-256`
21
- - source: `epoch%3D319-step%3D1600000_emainit.ckpt`
22
- - target resolution: `256x256`
23
- - `PixNerd-XL-16-512`
24
- - source: `res512_ft200k_epoch%3D325-step%3D1800000_emainit.ckpt`
25
- - target resolution: `512x512`
26
 
27
- Both checkpoints are packaged with:
28
 
29
- - `pipeline.py`
30
- - `modeling_pixnerd_transformer_2d.py`
31
- - `scheduling_pixnerd_flow_match.py`
32
- - `transformer/` weights + config
33
- - `scheduler/` config
34
 
35
- ## Requirements
 
 
 
36
 
37
- ```bash
38
- pip install torch diffusers
39
- ```
 
 
 
 
 
 
 
 
 
40
 
41
- ## Inference (Python)
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
  ```python
44
  import torch
45
  from diffusers import DiffusionPipeline
46
 
47
- model_dir = "PixNerd-XL-16-256" # or PixNerd-XL-16-512
 
 
48
  pipe = DiffusionPipeline.from_pretrained(
49
- model_dir,
50
- custom_pipeline=f"{model_dir}/pipeline.py",
51
  torch_dtype=torch.float32,
52
- ).to("cpu") # use "cuda" if available
53
 
54
- # Class-conditional generation: class label 207 (golden retriever)
55
  images = pipe(
56
- prompt=[207],
57
- num_images_per_prompt=1,
58
- height=256,
59
- width=256,
60
  num_inference_steps=25,
61
  guidance_scale=4.0,
62
  timeshift=3.0,
63
  order=2,
64
  ).images
65
 
66
- images[0].save("sample.png")
 
 
 
67
  ```
68
 
69
- ## Interface Notes
 
 
 
 
 
 
 
 
 
 
 
 
 
70
 
71
- - The pipeline uses `prompt` for conditioning input.
72
- - For class-conditional generation, pass integer labels, e.g. `prompt=[207]`.
73
- - `height` and `width` should match checkpoint intent (256 or 512), but custom sizes work if divisible by patch size.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
- ## Reproducibility Metadata
76
 
 
 
 
77
  - Architecture and conversion provenance are recorded in each checkpoint's `conversion_metadata.json`.
78
- - Transformer and scheduler runtime classes are defined in repository-local Python modules shipped with each checkpoint.
79
 
80
  ## Limitations
81
 
 
11
  - en
12
  ---
13
 
14
+ # BiliSakura/PixNerd-diffusers
15
 
16
+ Self-contained PixNerd-XL/16 checkpoints for Hugging Face diffusers. **No external code repo is required** — each subfolder ships its own `pipeline.py`, component modules, and weights.
17
 
18
+ This repo is derived from the development bundle in [Visual-Generative-Foundation-Model-Collection](https://github.com/Bili-Sakura/Visual-Generative-Foundation-Model-Collection), but inference only needs:
19
 
20
+ - This model repo (`BiliSakura/PixNerd-diffusers`)
21
+ - PyPI `diffusers`, `torch`, `huggingface_hub`
 
 
 
 
22
 
23
+ This Hugging Face repo hosts **multiple self-contained checkpoints as subfolders**. Each subfolder includes its own `pipeline.py`, `model_index.json`, weights, and component code (`transformer/`, `scheduler/`).
24
 
25
+ ## Available checkpoints
 
 
 
 
26
 
27
+ | Subfolder | Resolution | Source checkpoint |
28
+ | --- | --- | --- |
29
+ | [`PixNerd-XL-16-256/`](PixNerd-XL-16-256/) | 256×256 | `epoch%3D319-step%3D1600000_emainit.ckpt` |
30
+ | [`PixNerd-XL-16-512/`](PixNerd-XL-16-512/) | 512×512 | `res512_ft200k_epoch%3D325-step%3D1800000_emainit.ckpt` |
31
 
32
+ Both checkpoints are ImageNet class-conditional PixNerd-XL/16 exports with flow-matching sampling.
33
+
34
+ ## ImageNet class labels
35
+
36
+ ImageNet-1k labels live in shared [`labels/`](labels/) at the repo root (not duplicated per variant). Format follows Hugging Face / DiT convention:
37
+
38
+ | File | Direction | Value format |
39
+ | --- | --- | --- |
40
+ | `labels/id2label_en.json` | id → English | comma-separated synonyms, e.g. `"207": "golden retriever"` |
41
+ | `labels/id2label_cn.json` | id → Chinese | comma-separated synonyms, e.g. `"207": "金毛猎犬"` |
42
+
43
+ After `PixNerdPipeline.from_pretrained(...)`, the pipeline exposes:
44
 
45
+ - `pipe.id2label` / `pipe.id2label_cn` — inspect id → label correspondence
46
+ - `pipe.labels` / `pipe.labels_cn` — reverse maps (synonym → id), sorted for browsing
47
+ - `pipe.get_label_ids("golden retriever")` or `pipe.get_label_ids("金毛猎犬", lang="cn")`
48
+ - `pipe(prompt="golden retriever", ...)` — string labels resolved automatically
49
+
50
+ Why JSON at repo root instead of a Python dict in each variant?
51
+
52
+ 1. **Explicit correspondence** — users can open `id2label_en.json` and see every id without running code.
53
+ 2. **Hub-compatible** — same shape as `facebook/DiT-XL-2-256` and other vision checkpoints.
54
+ 3. **Shared across variants** — both PixNerd checkpoints use the same 1000 ImageNet classes.
55
+ 4. **Bilingual without duplication** — English and Chinese are separate files; the pipeline loads both.
56
+
57
+ ## Load from Hugging Face
58
 
59
  ```python
60
  import torch
61
  from diffusers import DiffusionPipeline
62
 
63
+ variant = "PixNerd-XL-16-256" # or PixNerd-XL-16-512
64
+ resolution = 256 if variant.endswith("256") else 512
65
+
66
  pipe = DiffusionPipeline.from_pretrained(
67
+ f"BiliSakura/PixNerd-diffusers/{variant}",
68
+ trust_remote_code=True,
69
  torch_dtype=torch.float32,
70
+ ).to("cuda")
71
 
 
72
  images = pipe(
73
+ prompt=207,
74
+ height=resolution,
75
+ width=resolution,
 
76
  num_inference_steps=25,
77
  guidance_scale=4.0,
78
  timeshift=3.0,
79
  order=2,
80
  ).images
81
 
82
+ print(pipe.id2label[207]) # "golden retriever"
83
+ print(pipe.id2label_cn[207]) # "金毛猎犬"
84
+ pipe.get_label_ids("golden retriever") # [207]
85
+ images = pipe(prompt="golden retriever", height=resolution, width=resolution).images
86
  ```
87
 
88
+ ## Load from a local clone
89
+
90
+ ```python
91
+ import torch
92
+ from diffusers import DiffusionPipeline
93
+
94
+ repo = "models/BiliSakura/PixNerd-diffusers"
95
+ variant = "PixNerd-XL-16-256"
96
+
97
+ pipe = DiffusionPipeline.from_pretrained(
98
+ f"{repo}/{variant}",
99
+ trust_remote_code=True,
100
+ torch_dtype=torch.float32,
101
+ ).to("cuda")
102
 
103
+ images = pipe(prompt=207, height=256, width=256).images
104
+ ```
105
+
106
+ ## Repo layout
107
+
108
+ ```text
109
+ BiliSakura/PixNerd-diffusers/
110
+ ├── README.md
111
+ ├── labels/
112
+ │ ├── id2label_en.json # ImageNet id -> English synonyms
113
+ │ ├── id2label_cn.json # ImageNet id -> Chinese synonyms
114
+ │ └── imagenet_labels.py # loader helpers
115
+ ├── PixNerd-XL-16-256/
116
+ │ ├── README.md
117
+ │ ├── pipeline.py
118
+ │ ├── model_index.json
119
+ │ ├── conversion_metadata.json
120
+ │ ├── transformer/
121
+ │ └── scheduler/
122
+ └── PixNerd-XL-16-512/
123
+ ├── README.md
124
+ ├── pipeline.py
125
+ ├── model_index.json
126
+ ├── conversion_metadata.json
127
+ ├── transformer/
128
+ └── scheduler/
129
+ ```
130
 
131
+ ## Interface notes
132
 
133
+ - The pipeline uses `prompt` for class conditioning input.
134
+ - Pass integer ImageNet ids (`prompt=207`) or human-readable synonyms (`prompt="golden retriever"`).
135
+ - `height` and `width` should match checkpoint intent (256 or 512), but custom sizes work if divisible by patch size (16).
136
  - Architecture and conversion provenance are recorded in each checkpoint's `conversion_metadata.json`.
 
137
 
138
  ## Limitations
139
 
labels/id2label_cn.json ADDED
@@ -0,0 +1,1002 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "0": "丁鲷",
3
+ "1": "金鱼",
4
+ "2": "大白鲨",
5
+ "3": "虎鲨",
6
+ "4": "锤头鲨",
7
+ "5": "电鳐",
8
+ "6": "黄貂鱼",
9
+ "7": "公鸡",
10
+ "8": "母鸡",
11
+ "9": "鸵鸟",
12
+ "10": "燕雀",
13
+ "11": "金翅雀",
14
+ "12": "家朱雀",
15
+ "13": "灯芯草雀",
16
+ "14": "靛蓝雀,靛蓝鸟",
17
+ "15": "蓝鹀",
18
+ "16": "夜莺",
19
+ "17": "松鸦",
20
+ "18": "喜鹊",
21
+ "19": "山雀",
22
+ "20": "河鸟",
23
+ "21": "鸢(猛禽)",
24
+ "22": "秃头鹰",
25
+ "23": "秃鹫",
26
+ "24": "大灰猫头鹰",
27
+ "25": "欧洲火蝾螈",
28
+ "26": "普通蝾螈",
29
+ "27": "水蜥",
30
+ "28": "斑点蝾螈",
31
+ "29": "蝾螈,泥狗",
32
+ "30": "牛蛙",
33
+ "31": "树蛙",
34
+ "32": "尾蛙,铃蟾蜍,肋蟾蜍,尾蟾蜍",
35
+ "33": "红海龟",
36
+ "34": "皮革龟",
37
+ "35": "泥龟",
38
+ "36": "淡水龟",
39
+ "37": "箱龟",
40
+ "38": "带状壁虎",
41
+ "39": "普通鬣蜥",
42
+ "40": "美国变色龙",
43
+ "41": "鞭尾蜥蜴",
44
+ "42": "飞龙科蜥蜴",
45
+ "43": "褶边蜥蜴",
46
+ "44": "鳄鱼蜥蜴",
47
+ "45": "毒蜥",
48
+ "46": "绿蜥蜴",
49
+ "47": "非洲变色龙",
50
+ "48": "科莫多蜥蜴",
51
+ "49": "非洲鳄,尼罗河鳄鱼",
52
+ "50": "美国鳄鱼,鳄鱼",
53
+ "51": "三角龙",
54
+ "52": "雷蛇,蠕虫蛇",
55
+ "53": "环蛇,环颈蛇",
56
+ "54": "希腊蛇",
57
+ "55": "绿蛇,草蛇",
58
+ "56": "国王蛇",
59
+ "57": "袜带蛇,草蛇",
60
+ "58": "水蛇",
61
+ "59": "藤蛇",
62
+ "60": "夜蛇",
63
+ "61": "大蟒蛇",
64
+ "62": "岩石蟒蛇,岩蛇,蟒蛇",
65
+ "63": "印度眼镜蛇",
66
+ "64": "绿曼巴",
67
+ "65": "海蛇",
68
+ "66": "角腹蛇",
69
+ "67": "菱纹响尾蛇",
70
+ "68": "角响尾蛇",
71
+ "69": "三叶虫",
72
+ "70": "盲蜘蛛",
73
+ "71": "蝎子",
74
+ "72": "黑金花园蜘蛛",
75
+ "73": "谷仓蜘蛛",
76
+ "74": "花园蜘蛛",
77
+ "75": "黑寡妇蜘蛛",
78
+ "76": "狼蛛",
79
+ "77": "狼蜘蛛,狩猎蜘蛛",
80
+ "78": "壁虱",
81
+ "79": "蜈蚣",
82
+ "80": "黑松鸡",
83
+ "81": "松鸡,雷鸟",
84
+ "82": "披肩鸡,披肩榛鸡",
85
+ "83": "草原鸡,草原松鸡",
86
+ "84": "孔雀",
87
+ "85": "鹌鹑",
88
+ "86": "鹧鸪",
89
+ "87": "非洲灰鹦鹉",
90
+ "88": "金刚鹦鹉",
91
+ "89": "硫冠鹦鹉",
92
+ "90": "短尾鹦鹉",
93
+ "91": "褐翅鸦鹃",
94
+ "92": "蜜蜂",
95
+ "93": "犀鸟",
96
+ "94": "蜂鸟",
97
+ "95": "鹟䴕",
98
+ "96": "犀鸟",
99
+ "97": "野鸭",
100
+ "98": "红胸秋沙鸭",
101
+ "99": "鹅",
102
+ "100": "黑天鹅",
103
+ "101": "大象",
104
+ "102": "针鼹鼠",
105
+ "103": "鸭嘴兽",
106
+ "104": "沙袋鼠",
107
+ "105": "考拉,考拉熊",
108
+ "106": "袋熊",
109
+ "107": "水母",
110
+ "108": "海葵",
111
+ "109": "脑珊瑚",
112
+ "110": "扁形虫扁虫",
113
+ "111": "线虫,蛔虫",
114
+ "112": "海螺",
115
+ "113": "蜗牛",
116
+ "114": "鼻涕虫",
117
+ "115": "海参",
118
+ "116": "石鳖",
119
+ "117": "鹦鹉螺",
120
+ "118": "珍宝蟹",
121
+ "119": "石蟹",
122
+ "120": "招潮蟹",
123
+ "121": "帝王蟹,阿拉斯加蟹,阿拉斯加帝王蟹",
124
+ "122": "美国龙虾,缅因州龙虾",
125
+ "123": "大螯虾",
126
+ "124": "小龙虾",
127
+ "125": "寄居蟹",
128
+ "126": "等足目动物(明虾和螃蟹近亲)",
129
+ "127": "白鹳",
130
+ "128": "黑鹳",
131
+ "129": "鹭",
132
+ "130": "火烈鸟",
133
+ "131": "小蓝鹭",
134
+ "132": "美国鹭,大白鹭",
135
+ "133": "麻鸦",
136
+ "134": "鹤",
137
+ "135": "秧鹤",
138
+ "136": "欧洲水鸡,紫水鸡",
139
+ "137": "沼泽泥母鸡,水母鸡",
140
+ "138": "鸨",
141
+ "139": "红翻石鹬",
142
+ "140": "红背鹬,黑腹滨鹬",
143
+ "141": "红脚鹬",
144
+ "142": "半蹼鹬",
145
+ "143": "蛎鹬",
146
+ "144": "鹈鹕",
147
+ "145": "国王企鹅",
148
+ "146": "信天翁,大海鸟",
149
+ "147": "灰鲸",
150
+ "148": "杀人鲸,逆戟鲸,虎鲸",
151
+ "149": "海牛",
152
+ "150": "海狮",
153
+ "151": "奇瓦瓦",
154
+ "152": "日本猎犬",
155
+ "153": "马尔济斯犬",
156
+ "154": "狮子狗",
157
+ "155": "西施犬",
158
+ "156": "布莱尼姆猎犬",
159
+ "157": "巴比狗",
160
+ "158": "玩具犬",
161
+ "159": "罗得西亚长背猎狗",
162
+ "160": "阿富汗猎犬",
163
+ "161": "猎犬",
164
+ "162": "比格犬,猎兔犬",
165
+ "163": "侦探犬",
166
+ "164": "蓝色快狗",
167
+ "165": "黑褐猎浣熊犬",
168
+ "166": "沃克猎犬",
169
+ "167": "英国猎狐犬",
170
+ "168": "美洲赤狗",
171
+ "169": "俄罗斯猎狼犬",
172
+ "170": "爱尔兰猎狼犬",
173
+ "171": "意大利灰狗",
174
+ "172": "惠比特犬",
175
+ "173": "依比沙猎犬",
176
+ "174": "挪威猎犬",
177
+ "175": "奥达猎犬,水獭猎犬",
178
+ "176": "沙克犬,瞪羚猎犬",
179
+ "177": "苏格兰猎鹿犬,猎鹿犬",
180
+ "178": "威玛猎犬",
181
+ "179": "斯塔福德郡牛头梗,斯塔福德郡斗牛梗",
182
+ "180": "美国斯塔福德郡梗,美国比特斗牛梗,斗牛梗",
183
+ "181": "贝德灵顿梗",
184
+ "182": "边境梗",
185
+ "183": "凯丽蓝梗",
186
+ "184": "爱尔兰梗",
187
+ "185": "诺福克梗",
188
+ "186": "诺维奇梗",
189
+ "187": "约克郡梗",
190
+ "188": "刚毛猎狐梗",
191
+ "189": "莱克兰梗",
192
+ "190": "锡利哈姆梗",
193
+ "191": "艾尔谷犬",
194
+ "192": "凯恩梗",
195
+ "193": "澳大利亚梗",
196
+ "194": "丹迪丁蒙梗",
197
+ "195": "波士顿梗",
198
+ "196": "迷你雪纳瑞犬",
199
+ "197": "巨型雪纳瑞犬",
200
+ "198": "标准雪纳瑞犬",
201
+ "199": "苏格兰梗",
202
+ "200": "西藏梗,菊花狗",
203
+ "201": "丝毛梗",
204
+ "202": "软毛麦色梗",
205
+ "203": "西高地白梗",
206
+ "204": "拉萨阿普索犬",
207
+ "205": "平毛寻回犬",
208
+ "206": "卷毛寻回犬",
209
+ "207": "金毛猎犬",
210
+ "208": "拉布拉多猎犬",
211
+ "209": "乞沙比克猎犬",
212
+ "210": "德国短毛猎犬",
213
+ "211": "维兹拉犬",
214
+ "212": "英国谍犬",
215
+ "213": "爱尔兰雪达犬,红色猎犬",
216
+ "214": "戈登雪达犬",
217
+ "215": "布列塔尼犬猎犬",
218
+ "216": "黄毛,黄毛猎犬",
219
+ "217": "英国史宾格犬",
220
+ "218": "威尔士史宾格犬",
221
+ "219": "可卡犬,英国可卡犬",
222
+ "220": "萨塞克斯猎犬",
223
+ "221": "爱尔兰水猎犬",
224
+ "222": "哥威斯犬",
225
+ "223": "舒柏奇犬",
226
+ "224": "比利时牧羊犬",
227
+ "225": "马里努阿犬",
228
+ "226": "伯瑞犬",
229
+ "227": "凯尔皮犬",
230
+ "228": "匈牙利牧羊犬",
231
+ "229": "老英国牧羊犬",
232
+ "230": "喜乐蒂牧羊犬",
233
+ "231": "牧羊犬",
234
+ "232": "边境牧羊犬",
235
+ "233": "法兰德斯牧牛狗",
236
+ "234": "罗特韦尔犬",
237
+ "235": "德国牧羊犬,德国警犬,阿尔萨斯",
238
+ "236": "多伯曼犬,杜宾犬",
239
+ "237": "迷你杜宾犬",
240
+ "238": "大瑞士山地犬",
241
+ "239": "伯恩山犬",
242
+ "240": "Appenzeller狗",
243
+ "241": "EntleBucher狗",
244
+ "242": "拳师狗",
245
+ "243": "斗牛獒",
246
+ "244": "藏獒",
247
+ "245": "法国斗牛犬",
248
+ "246": "大丹犬",
249
+ "247": "圣伯纳德狗",
250
+ "248": "爱斯基摩犬,哈士奇",
251
+ "249": "雪橇犬,阿拉斯加爱斯基摩狗",
252
+ "250": "哈士奇",
253
+ "251": "达尔马提亚,教练车狗",
254
+ "252": "狮毛狗",
255
+ "253": "巴辛吉狗",
256
+ "254": "哈巴狗,狮子狗",
257
+ "255": "莱昂贝格狗",
258
+ "256": "纽芬兰岛狗",
259
+ "257": "大白熊犬",
260
+ "258": "萨摩耶犬",
261
+ "259": "博美犬",
262
+ "260": "松狮,松狮",
263
+ "261": "荷兰卷尾狮毛狗",
264
+ "262": "布鲁塞尔格林芬犬",
265
+ "263": "彭布洛克威尔士科基犬",
266
+ "264": "威尔士柯基犬",
267
+ "265": "玩具贵宾犬",
268
+ "266": "迷你贵宾犬",
269
+ "267": "标准贵宾犬",
270
+ "268": "墨西哥无毛犬",
271
+ "269": "灰狼",
272
+ "270": "白狼,北极狼",
273
+ "271": "红太狼,鬃狼,犬犬鲁弗斯",
274
+ "272": "狼,草原狼,刷狼,郊狼",
275
+ "273": "澳洲野狗,澳大利亚野犬",
276
+ "274": "豺",
277
+ "275": "非洲猎犬,土狼犬",
278
+ "276": "鬣狗",
279
+ "277": "红狐狸",
280
+ "278": "沙狐",
281
+ "279": "北极狐狸,白狐狸",
282
+ "280": "灰狐狸",
283
+ "281": "虎斑猫",
284
+ "282": "山猫,虎猫",
285
+ "283": "波斯猫",
286
+ "284": "暹罗暹罗猫,",
287
+ "285": "埃及猫",
288
+ "286": "美洲狮,美洲豹",
289
+ "287": "猞猁,山猫",
290
+ "288": "豹子",
291
+ "289": "雪豹",
292
+ "290": "美洲虎",
293
+ "291": "狮子",
294
+ "292": "老虎",
295
+ "293": "猎豹",
296
+ "294": "棕熊",
297
+ "295": "美洲黑熊",
298
+ "296": "冰熊,北极熊",
299
+ "297": "懒熊",
300
+ "298": "猫鼬",
301
+ "299": "猫鼬,海猫",
302
+ "300": "虎甲虫",
303
+ "301": "瓢虫",
304
+ "302": "土鳖虫",
305
+ "303": "天牛",
306
+ "304": "龟甲虫",
307
+ "305": "粪甲虫",
308
+ "306": "犀牛甲虫",
309
+ "307": "象甲",
310
+ "308": "苍蝇",
311
+ "309": "蜜蜂",
312
+ "310": "蚂蚁",
313
+ "311": "蚱蜢",
314
+ "312": "蟋蟀",
315
+ "313": "竹节虫",
316
+ "314": "蟑螂",
317
+ "315": "螳螂",
318
+ "316": "蝉",
319
+ "317": "叶蝉",
320
+ "318": "草蜻蛉",
321
+ "319": "蜻蜓",
322
+ "320": "豆娘,蜻蛉",
323
+ "321": "优红蛱蝶",
324
+ "322": "小环蝴蝶",
325
+ "323": "君主蝴蝶,大斑蝶",
326
+ "324": "菜粉蝶",
327
+ "325": "白蝴蝶",
328
+ "326": "灰蝶",
329
+ "327": "海星",
330
+ "328": "海胆",
331
+ "329": "海参,海黄瓜",
332
+ "330": "野兔",
333
+ "331": "兔",
334
+ "332": "安哥拉兔",
335
+ "333": "仓鼠",
336
+ "334": "刺猬,豪猪,",
337
+ "335": "黑松鼠",
338
+ "336": "土拨鼠",
339
+ "337": "海狸",
340
+ "338": "豚鼠,豚鼠",
341
+ "339": "栗色马",
342
+ "340": "斑马",
343
+ "341": "猪",
344
+ "342": "野猪",
345
+ "343": "疣猪",
346
+ "344": "河马",
347
+ "345": "牛",
348
+ "346": "水牛,亚洲水牛",
349
+ "347": "野牛",
350
+ "348": "公羊",
351
+ "349": "大角羊,洛矶山大角羊",
352
+ "350": "山羊",
353
+ "351": "狷羚",
354
+ "352": "黑斑羚",
355
+ "353": "瞪羚",
356
+ "354": "阿拉伯单峰骆驼,骆驼",
357
+ "355": "羊驼",
358
+ "356": "黄鼠狼",
359
+ "357": "水貂",
360
+ "358": "臭猫",
361
+ "359": "黑足鼬",
362
+ "360": "水獭",
363
+ "361": "臭鼬,木猫",
364
+ "362": "獾",
365
+ "363": "犰狳",
366
+ "364": "树懒",
367
+ "365": "猩猩,婆罗洲猩猩",
368
+ "366": "大猩猩",
369
+ "367": "黑猩猩",
370
+ "368": "长臂猿",
371
+ "369": "合趾猿长臂猿,合趾猿",
372
+ "370": "长尾猴",
373
+ "371": "赤猴",
374
+ "372": "狒狒",
375
+ "373": "恒河猴,猕猴",
376
+ "374": "白头叶猴",
377
+ "375": "疣猴",
378
+ "376": "长鼻猴",
379
+ "377": "狨(美洲产小型长尾猴)",
380
+ "378": "卷尾猴",
381
+ "379": "吼猴",
382
+ "380": "伶猴",
383
+ "381": "蜘蛛猴",
384
+ "382": "松鼠猴",
385
+ "383": "马达加斯加环尾狐猴,鼠狐猴",
386
+ "384": "大狐猴,马达加斯加大狐猴",
387
+ "385": "印度大象,亚洲象",
388
+ "386": "非洲象,非洲象",
389
+ "387": "小熊猫",
390
+ "388": "大熊猫",
391
+ "389": "杖鱼",
392
+ "390": "鳗鱼",
393
+ "391": "银鲑,银鲑���",
394
+ "392": "三色刺蝶鱼",
395
+ "393": "海葵鱼",
396
+ "394": "鲟鱼",
397
+ "395": "雀鳝",
398
+ "396": "狮子鱼",
399
+ "397": "河豚",
400
+ "398": "算盘",
401
+ "399": "长袍",
402
+ "400": "学位袍",
403
+ "401": "手风琴",
404
+ "402": "原声吉他",
405
+ "403": "航空母舰",
406
+ "404": "客机",
407
+ "405": "飞艇",
408
+ "406": "祭坛",
409
+ "407": "救护车",
410
+ "408": "水陆两用车",
411
+ "409": "模拟时钟",
412
+ "410": "蜂房",
413
+ "411": "围裙",
414
+ "412": "垃圾桶",
415
+ "413": "攻击步枪,枪",
416
+ "414": "背包",
417
+ "415": "面包店,面包铺,",
418
+ "416": "平衡木",
419
+ "417": "热气球",
420
+ "418": "圆珠笔",
421
+ "419": "创可贴",
422
+ "420": "班卓琴",
423
+ "421": "栏杆,楼梯扶手",
424
+ "422": "杠铃",
425
+ "423": "理发师的椅子",
426
+ "424": "理发店",
427
+ "425": "牲口棚",
428
+ "426": "晴雨表",
429
+ "427": "圆筒",
430
+ "428": "园地小车,手推车",
431
+ "429": "棒球",
432
+ "430": "篮球",
433
+ "431": "婴儿床",
434
+ "432": "巴松管,低音管",
435
+ "433": "游泳帽",
436
+ "434": "沐浴毛巾",
437
+ "435": "浴缸,澡盆",
438
+ "436": "沙滩车,旅行车",
439
+ "437": "灯塔",
440
+ "438": "高脚杯",
441
+ "439": "熊皮高帽",
442
+ "440": "啤酒瓶",
443
+ "441": "啤酒杯",
444
+ "442": "钟塔",
445
+ "443": "(小儿用的)围嘴",
446
+ "444": "串联自行车,",
447
+ "445": "比基尼",
448
+ "446": "装订册",
449
+ "447": "双筒望远镜",
450
+ "448": "鸟舍",
451
+ "449": "船库",
452
+ "450": "雪橇",
453
+ "451": "饰扣式领带",
454
+ "452": "阔边女帽",
455
+ "453": "书橱",
456
+ "454": "书店,书摊",
457
+ "455": "瓶盖",
458
+ "456": "弓箭",
459
+ "457": "蝴蝶结领结",
460
+ "458": "铜制牌位",
461
+ "459": "奶罩",
462
+ "460": "防波堤,海堤",
463
+ "461": "铠甲",
464
+ "462": "扫帚",
465
+ "463": "桶",
466
+ "464": "扣环",
467
+ "465": "防弹背心",
468
+ "466": "动车,子弹头列车",
469
+ "467": "肉铺,肉菜市场",
470
+ "468": "出租车",
471
+ "469": "大锅",
472
+ "470": "蜡烛",
473
+ "471": "大炮",
474
+ "472": "独木舟",
475
+ "473": "开瓶器,开罐器",
476
+ "474": "开衫",
477
+ "475": "车镜",
478
+ "476": "旋转木马",
479
+ "477": "木匠的工具包,工具包",
480
+ "478": "纸箱",
481
+ "479": "车轮",
482
+ "480": "取款机,自动取款机",
483
+ "481": "盒式录音带",
484
+ "482": "卡带播放器",
485
+ "483": "城堡",
486
+ "484": "双体船",
487
+ "485": "CD播放器",
488
+ "486": "大提琴",
489
+ "487": "移动电话,手机",
490
+ "488": "铁链",
491
+ "489": "围栏",
492
+ "490": "链甲",
493
+ "491": "电锯,油锯",
494
+ "492": "箱子",
495
+ "493": "衣柜,洗脸台",
496
+ "494": "编钟,钟,锣",
497
+ "495": "中国橱柜",
498
+ "496": "圣诞袜",
499
+ "497": "教堂,教堂建筑",
500
+ "498": "电影院,剧场",
501
+ "499": "切肉刀,菜刀",
502
+ "500": "悬崖屋",
503
+ "501": "斗篷",
504
+ "502": "木屐,木鞋",
505
+ "503": "鸡尾酒调酒器",
506
+ "504": "咖啡杯",
507
+ "505": "咖啡壶",
508
+ "506": "螺旋结构(楼梯)",
509
+ "507": "组合锁",
510
+ "508": "电脑键盘,键盘",
511
+ "509": "糖果,糖果店",
512
+ "510": "集装箱船",
513
+ "511": "敞篷车",
514
+ "512": "开瓶器,瓶螺杆",
515
+ "513": "短号,喇叭",
516
+ "514": "牛仔靴",
517
+ "515": "牛仔帽",
518
+ "516": "摇篮",
519
+ "517": "起重机",
520
+ "518": "头盔",
521
+ "519": "板条箱",
522
+ "520": "小儿床",
523
+ "521": "砂锅",
524
+ "522": "槌球",
525
+ "523": "拐杖",
526
+ "524": "胸甲",
527
+ "525": "大坝,堤防",
528
+ "526": "书桌",
529
+ "527": "台式电脑",
530
+ "528": "有线电话",
531
+ "529": "尿布湿",
532
+ "530": "数字时钟",
533
+ "531": "数字手表",
534
+ "532": "餐桌板",
535
+ "533": "抹布",
536
+ "534": "洗碗机,洗碟机",
537
+ "535": "盘式制动器",
538
+ "536": "码头,船坞,码头设施",
539
+ "537": "狗拉雪橇",
540
+ "538": "圆顶",
541
+ "539": "门垫,垫子",
542
+ "540": "钻井平台,海上钻井",
543
+ "541": "鼓,乐器,鼓膜",
544
+ "542": "鼓槌",
545
+ "543": "哑铃",
546
+ "544": "荷兰烤箱",
547
+ "545": "电风扇,鼓风机",
548
+ "546": "电吉他",
549
+ "547": "电力机车",
550
+ "548": "电视,电视柜",
551
+ "549": "信封",
552
+ "550": "浓缩咖啡机",
553
+ "551": "扑面粉",
554
+ "552": "女用长围巾",
555
+ "553": "文件,文件柜,档案柜",
556
+ "554": "消防船",
557
+ "555": "消防车",
558
+ "556": "火炉栏",
559
+ "557": "旗杆",
560
+ "558": "长笛",
561
+ "559": "折叠椅",
562
+ "560": "橄榄球头盔",
563
+ "561": "叉车",
564
+ "562": "喷泉",
565
+ "563": "钢笔",
566
+ "564": "有四根帷柱的床",
567
+ "565": "运货车厢",
568
+ "566": "圆号,喇叭",
569
+ "567": "煎锅",
570
+ "568": "裘皮大衣",
571
+ "569": "垃圾车",
572
+ "570": "防毒面具,呼吸器",
573
+ "571": "汽油泵",
574
+ "572": "高脚杯",
575
+ "573": "卡丁车",
576
+ "574": "高尔夫球",
577
+ "575": "高尔夫球车",
578
+ "576": "狭长小船",
579
+ "577": "锣",
580
+ "578": "礼服",
581
+ "579": "钢琴",
582
+ "580": "温室,苗圃",
583
+ "581": "散热器格栅",
584
+ "582": "杂货店,食品市场",
585
+ "583": "断头台",
586
+ "584": "小发夹",
587
+ "585": "头发喷雾",
588
+ "586": "半履带装甲车",
589
+ "587": "锤子",
590
+ "588": "大篮子",
591
+ "589": "手摇鼓风机,吹风机",
592
+ "590": "手提电脑",
593
+ "591": "手帕",
594
+ "592": "硬盘",
595
+ "593": "口琴,口风琴",
596
+ "594": "竖琴",
597
+ "595": "收割机",
598
+ "596": "斧头",
599
+ "597": "手枪皮套",
600
+ "598": "家庭影院",
601
+ "599": "蜂窝",
602
+ "600": "钩爪",
603
+ "601": "衬裙",
604
+ "602": "单杠",
605
+ "603": "马车",
606
+ "604": "沙漏",
607
+ "605": "手机,iPad",
608
+ "606": "熨斗",
609
+ "607": "南瓜灯笼",
610
+ "608": "牛仔裤,蓝色牛仔裤",
611
+ "609": "吉普车",
612
+ "610": "运动衫,T恤",
613
+ "611": "拼图",
614
+ "612": "人力车",
615
+ "613": "操纵杆",
616
+ "614": "和服",
617
+ "615": "护膝",
618
+ "616": "蝴蝶结",
619
+ "617": "大褂,实验室外套",
620
+ "618": "长柄勺",
621
+ "619": "灯罩",
622
+ "620": "笔记本电脑",
623
+ "621": "割草机",
624
+ "622": "镜头盖",
625
+ "623": "开信刀,裁纸刀",
626
+ "624": "图书馆",
627
+ "625": "救生艇",
628
+ "626": "点火器,打火机",
629
+ "627": "豪华轿车",
630
+ "628": "远洋班轮",
631
+ "629": "唇膏,口红",
632
+ "630": "平底便鞋",
633
+ "631": "洗剂",
634
+ "632": "扬声器",
635
+ "633": "放大镜",
636
+ "634": "锯木厂",
637
+ "635": "磁罗盘",
638
+ "636": "邮袋",
639
+ "637": "信箱",
640
+ "638": "女游泳衣",
641
+ "639": "有肩带浴衣",
642
+ "640": "窨井盖",
643
+ "641": "沙球(一种打击乐器)",
644
+ "642": "马林巴木琴",
645
+ "643": "面膜",
646
+ "644": "火柴",
647
+ "645": "花柱",
648
+ "646": "迷宫",
649
+ "647": "量杯",
650
+ "648": "药箱",
651
+ "649": "巨石,巨石结构",
652
+ "650": "麦克风",
653
+ "651": "微波炉",
654
+ "652": "军装",
655
+ "653": "奶桶",
656
+ "654": "迷你巴士",
657
+ "655": "迷你裙",
658
+ "656": "面包车",
659
+ "657": "导弹",
660
+ "658": "连指手套",
661
+ "659": "搅拌钵",
662
+ "660": "活动房屋(由汽车拖拉的)",
663
+ "661": "T型发动机小汽车",
664
+ "662": "调制解调器",
665
+ "663": "修道院",
666
+ "664": "显示器",
667
+ "665": "电瓶车",
668
+ "666": "砂浆",
669
+ "667": "学士",
670
+ "668": "清真寺",
671
+ "669": "蚊帐",
672
+ "670": "摩托车",
673
+ "671": "山地自行车",
674
+ "672": "登山帐",
675
+ "673": "鼠标,电脑鼠标",
676
+ "674": "捕鼠器",
677
+ "675": "搬家车",
678
+ "676": "口套",
679
+ "677": "钉子",
680
+ "678": "颈托",
681
+ "679": "项链",
682
+ "680": "乳头(瓶)",
683
+ "681": "笔记本,笔记本电脑",
684
+ "682": "方尖碑",
685
+ "683": "双簧管",
686
+ "684": "陶笛,卵形笛",
687
+ "685": "里程表",
688
+ "686": "滤油器",
689
+ "687": "风琴,管风琴",
690
+ "688": "示波器",
691
+ "689": "罩裙",
692
+ "690": "牛车",
693
+ "691": "氧气面罩",
694
+ "692": "包装",
695
+ "693": "船桨",
696
+ "694": "明轮,桨轮",
697
+ "695": "挂锁,扣锁",
698
+ "696": "画笔",
699
+ "697": "睡衣",
700
+ "698": "宫殿",
701
+ "699": "排箫,鸣管",
702
+ "700": "纸巾",
703
+ "701": "降落伞",
704
+ "702": "双杠",
705
+ "703": "公园长椅",
706
+ "704": "停车收费表,停车计时器",
707
+ "705": "客车,教练车",
708
+ "706": "露台,阳台",
709
+ "707": "付费电话",
710
+ "708": "基座,基脚",
711
+ "709": "铅笔盒",
712
+ "710": "卷笔刀",
713
+ "711": "香水(瓶)",
714
+ "712": "培养皿",
715
+ "713": "复印机",
716
+ "714": "拨弦片,拨子",
717
+ "715": "尖顶头盔",
718
+ "716": "栅栏,栅栏",
719
+ "717": "皮卡,皮卡车",
720
+ "718": "桥墩",
721
+ "719": "存钱罐",
722
+ "720": "药瓶",
723
+ "721": "枕头",
724
+ "722": "乒乓球",
725
+ "723": "风车",
726
+ "724": "海盗船",
727
+ "725": "水罐",
728
+ "726": "木工刨",
729
+ "727": "天文馆",
730
+ "728": "塑料袋",
731
+ "729": "板架",
732
+ "730": "犁型铲雪机",
733
+ "731": "手压皮碗泵",
734
+ "732": "宝丽来相机",
735
+ "733": "电线杆",
736
+ "734": "警车,巡逻车",
737
+ "735": "雨披",
738
+ "736": "台球桌",
739
+ "737": "充气饮料瓶",
740
+ "738": "花盆",
741
+ "739": "陶工旋盘",
742
+ "740": "电钻",
743
+ "741": "祈祷垫,地毯",
744
+ "742": "打印机",
745
+ "743": "监狱",
746
+ "744": "炮弹,导弹",
747
+ "745": "投影仪",
748
+ "746": "冰球",
749
+ "747": "沙包,吊球",
750
+ "748": "钱包",
751
+ "749": "羽管笔",
752
+ "750": "被子",
753
+ "751": "赛车",
754
+ "752": "球拍",
755
+ "753": "散热器",
756
+ "754": "收音机",
757
+ "755": "射电望远镜,无线电反射器",
758
+ "756": "雨桶",
759
+ "757": "休闲车,房车",
760
+ "758": "卷轴,卷筒",
761
+ "759": "反射式照相机",
762
+ "760": "冰箱,冰柜",
763
+ "761": "遥控器",
764
+ "762": "餐厅,饮食店,食堂",
765
+ "763": "左轮手枪",
766
+ "764": "步枪",
767
+ "765": "摇椅",
768
+ "766": "电转烤肉架",
769
+ "767": "橡皮",
770
+ "768": "橄榄球",
771
+ "769": "直尺",
772
+ "770": "跑步鞋",
773
+ "771": "保险柜",
774
+ "772": "安全别针",
775
+ "773": "盐瓶(调味用)",
776
+ "774": "凉鞋",
777
+ "775": "纱笼,围裙",
778
+ "776": "萨克斯管",
779
+ "777": "剑鞘",
780
+ "778": "秤,称重机",
781
+ "779": "校车",
782
+ "780": "帆船",
783
+ "781": "记分牌",
784
+ "782": "屏幕",
785
+ "783": "螺丝",
786
+ "784": "螺丝刀",
787
+ "785": "安全带",
788
+ "786": "缝纫机",
789
+ "787": "盾牌,盾牌",
790
+ "788": "皮鞋店,鞋店",
791
+ "789": "障子",
792
+ "790": "购物篮",
793
+ "791": "购物车",
794
+ "792": "铁锹",
795
+ "793": "浴帽",
796
+ "794": "浴帘",
797
+ "795": "滑雪板",
798
+ "796": "滑雪面罩",
799
+ "797": "睡袋",
800
+ "798": "滑尺",
801
+ "799": "滑动门",
802
+ "800": "角子老虎机",
803
+ "801": "潜水通气管",
804
+ "802": "雪橇",
805
+ "803": "扫雪机,扫雪机",
806
+ "804": "皂液器",
807
+ "805": "足球",
808
+ "806": "袜子",
809
+ "807": "碟式太阳能,太阳能集热器,太阳能炉",
810
+ "808": "宽边帽",
811
+ "809": "汤碗",
812
+ "810": "空格键",
813
+ "811": "空间加热器",
814
+ "812": "航天飞机",
815
+ "813": "铲(搅拌或涂敷用的)",
816
+ "814": "快艇",
817
+ "815": "蜘蛛网",
818
+ "816": "纺锤,纱锭",
819
+ "817": "跑车",
820
+ "818": "聚光灯",
821
+ "819": "舞台",
822
+ "820": "蒸汽机车",
823
+ "821": "钢拱桥",
824
+ "822": "钢滚筒",
825
+ "823": "听诊器",
826
+ "824": "女用披肩",
827
+ "825": "石头墙",
828
+ "826": "秒表",
829
+ "827": "火炉",
830
+ "828": "过滤器",
831
+ "829": "有轨电车,电车",
832
+ "830": "担架",
833
+ "831": "沙发床",
834
+ "832": "佛塔",
835
+ "833": "潜艇,潜水艇",
836
+ "834": "套装,衣服",
837
+ "835": "日晷",
838
+ "836": "太阳镜",
839
+ "837": "太阳镜,墨镜",
840
+ "838": "防晒霜,防晒剂",
841
+ "839": "悬索桥",
842
+ "840": "拖把",
843
+ "841": "运动衫",
844
+ "842": "游泳裤",
845
+ "843": "秋千",
846
+ "844": "开关,电器开关",
847
+ "845": "注射器",
848
+ "846": "台灯",
849
+ "847": "坦克,装甲战车,装甲战斗车辆",
850
+ "848": "磁带播放器",
851
+ "849": "茶壶",
852
+ "850": "泰迪,泰迪熊",
853
+ "851": "电视",
854
+ "852": "网球",
855
+ "853": "茅草,茅草屋顶",
856
+ "854": "幕布,剧院的帷幕",
857
+ "855": "顶针",
858
+ "856": "脱粒机",
859
+ "857": "宝座",
860
+ "858": "瓦屋顶",
861
+ "859": "烤面包机",
862
+ "860": "烟草店,烟草",
863
+ "861": "马桶",
864
+ "862": "火炬",
865
+ "863": "图腾柱",
866
+ "864": "拖车,牵引车,清障车",
867
+ "865": "玩具店",
868
+ "866": "拖拉机",
869
+ "867": "拖车,铰接式卡车",
870
+ "868": "托盘",
871
+ "869": "风衣",
872
+ "870": "三轮车",
873
+ "871": "三体船",
874
+ "872": "三脚架",
875
+ "873": "凯旋门",
876
+ "874": "无轨电车",
877
+ "875": "长号",
878
+ "876": "浴盆,浴缸",
879
+ "877": "旋转式栅门",
880
+ "878": "打字机键盘",
881
+ "879": "伞",
882
+ "880": "独轮车",
883
+ "881": "直立式钢琴",
884
+ "882": "真空吸尘器",
885
+ "883": "花瓶",
886
+ "884": "拱顶",
887
+ "885": "天鹅绒",
888
+ "886": "自动售货机",
889
+ "887": "祭服",
890
+ "888": "高架桥",
891
+ "889": "小提琴,小提琴",
892
+ "890": "排球",
893
+ "891": "松饼机",
894
+ "892": "挂钟",
895
+ "893": "钱包,皮夹",
896
+ "894": "衣柜,壁橱",
897
+ "895": "军用飞机",
898
+ "896": "洗脸盆,洗手盆",
899
+ "897": "洗衣机,自动洗衣机",
900
+ "898": "水瓶",
901
+ "899": "水壶",
902
+ "900": "水塔",
903
+ "901": "威士忌壶",
904
+ "902": "哨子",
905
+ "903": "假发",
906
+ "904": "纱窗",
907
+ "905": "百叶窗",
908
+ "906": "温莎领带",
909
+ "907": "葡萄酒瓶",
910
+ "908": "飞机翅膀,飞机",
911
+ "909": "炒菜锅",
912
+ "910": "木制的勺子",
913
+ "911": "毛织品,羊绒",
914
+ "912": "栅栏,围栏",
915
+ "913": "沉船",
916
+ "914": "双桅船",
917
+ "915": "蒙古包",
918
+ "916": "网站,互联网网站",
919
+ "917": "漫画",
920
+ "918": "纵横字谜",
921
+ "919": "路标",
922
+ "920": "交通信号灯",
923
+ "921": "防尘罩,书皮",
924
+ "922": "菜单",
925
+ "923": "盘子",
926
+ "924": "鳄梨酱",
927
+ "925": "清汤",
928
+ "926": "罐焖土豆烧肉",
929
+ "927": "蛋糕",
930
+ "928": "冰淇淋",
931
+ "929": "雪糕,冰棍,冰棒",
932
+ "930": "法式面包",
933
+ "931": "百吉饼",
934
+ "932": "椒盐脆饼",
935
+ "933": "芝士汉堡",
936
+ "934": "热狗",
937
+ "935": "土豆泥",
938
+ "936": "结球甘蓝",
939
+ "937": "西兰花",
940
+ "938": "菜花",
941
+ "939": "绿皮密生西葫芦",
942
+ "940": "西葫芦",
943
+ "941": "小青南瓜",
944
+ "942": "南瓜",
945
+ "943": "黄瓜",
946
+ "944": "朝鲜蓟",
947
+ "945": "甜椒",
948
+ "946": "刺棘蓟",
949
+ "947": "蘑菇",
950
+ "948": "绿苹果",
951
+ "949": "草莓",
952
+ "950": "橘子",
953
+ "951": "柠檬",
954
+ "952": "无花果",
955
+ "953": "菠萝",
956
+ "954": "香蕉",
957
+ "955": "菠萝蜜",
958
+ "956": "蛋奶冻苹果",
959
+ "957": "石榴",
960
+ "958": "干草",
961
+ "959": "烤面条加干酪沙司",
962
+ "960": "巧克力酱,巧克力糖浆",
963
+ "961": "面团",
964
+ "962": "瑞士肉包,肉饼",
965
+ "963": "披萨,披萨饼",
966
+ "964": "馅饼",
967
+ "965": "卷饼",
968
+ "966": "红葡萄酒",
969
+ "967": "意大利浓咖啡",
970
+ "968": "杯子",
971
+ "969": "蛋酒",
972
+ "970": "高山",
973
+ "971": "泡泡",
974
+ "972": "悬崖",
975
+ "973": "珊瑚礁",
976
+ "974": "间歇泉",
977
+ "975": "湖边,湖岸",
978
+ "976": "海角",
979
+ "977": "沙洲,沙坝",
980
+ "978": "海滨,海岸",
981
+ "979": "峡谷",
982
+ "980": "火山",
983
+ "981": "棒球,棒球运动员",
984
+ "982": "新郎",
985
+ "983": "潜水员",
986
+ "984": "油菜",
987
+ "985": "雏菊",
988
+ "986": "杓兰",
989
+ "987": "玉米",
990
+ "988": "橡子",
991
+ "989": "玫瑰果",
992
+ "990": "七叶树果实",
993
+ "991": "珊瑚菌",
994
+ "992": "木耳",
995
+ "993": "鹿花菌",
996
+ "994": "鬼笔菌",
997
+ "995": "地星(菌类)",
998
+ "996": "多叶奇果菌",
999
+ "997": "牛肝菌",
1000
+ "998": "玉米穗",
1001
+ "999": "卫生纸"
1002
+ }
labels/id2label_en.json ADDED
@@ -0,0 +1,1002 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "0": "tench, Tinca tinca",
3
+ "1": "goldfish, Carassius auratus",
4
+ "2": "great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias",
5
+ "3": "tiger shark, Galeocerdo cuvieri",
6
+ "4": "hammerhead, hammerhead shark",
7
+ "5": "electric ray, crampfish, numbfish, torpedo",
8
+ "6": "stingray",
9
+ "7": "cock",
10
+ "8": "hen",
11
+ "9": "ostrich, Struthio camelus",
12
+ "10": "brambling, Fringilla montifringilla",
13
+ "11": "goldfinch, Carduelis carduelis",
14
+ "12": "house finch, linnet, Carpodacus mexicanus",
15
+ "13": "junco, snowbird",
16
+ "14": "indigo bunting, indigo finch, indigo bird, Passerina cyanea",
17
+ "15": "robin, American robin, Turdus migratorius",
18
+ "16": "bulbul",
19
+ "17": "jay",
20
+ "18": "magpie",
21
+ "19": "chickadee",
22
+ "20": "water ouzel, dipper",
23
+ "21": "kite",
24
+ "22": "bald eagle, American eagle, Haliaeetus leucocephalus",
25
+ "23": "vulture",
26
+ "24": "great grey owl, great gray owl, Strix nebulosa",
27
+ "25": "European fire salamander, Salamandra salamandra",
28
+ "26": "common newt, Triturus vulgaris",
29
+ "27": "eft",
30
+ "28": "spotted salamander, Ambystoma maculatum",
31
+ "29": "axolotl, mud puppy, Ambystoma mexicanum",
32
+ "30": "bullfrog, Rana catesbeiana",
33
+ "31": "tree frog, tree-frog",
34
+ "32": "tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui",
35
+ "33": "loggerhead, loggerhead turtle, Caretta caretta",
36
+ "34": "leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea",
37
+ "35": "mud turtle",
38
+ "36": "terrapin",
39
+ "37": "box turtle, box tortoise",
40
+ "38": "banded gecko",
41
+ "39": "common iguana, iguana, Iguana iguana",
42
+ "40": "American chameleon, anole, Anolis carolinensis",
43
+ "41": "whiptail, whiptail lizard",
44
+ "42": "agama",
45
+ "43": "frilled lizard, Chlamydosaurus kingi",
46
+ "44": "alligator lizard",
47
+ "45": "Gila monster, Heloderma suspectum",
48
+ "46": "green lizard, Lacerta viridis",
49
+ "47": "African chameleon, Chamaeleo chamaeleon",
50
+ "48": "Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis",
51
+ "49": "African crocodile, Nile crocodile, Crocodylus niloticus",
52
+ "50": "American alligator, Alligator mississipiensis",
53
+ "51": "triceratops",
54
+ "52": "thunder snake, worm snake, Carphophis amoenus",
55
+ "53": "ringneck snake, ring-necked snake, ring snake",
56
+ "54": "hognose snake, puff adder, sand viper",
57
+ "55": "green snake, grass snake",
58
+ "56": "king snake, kingsnake",
59
+ "57": "garter snake, grass snake",
60
+ "58": "water snake",
61
+ "59": "vine snake",
62
+ "60": "night snake, Hypsiglena torquata",
63
+ "61": "boa constrictor, Constrictor constrictor",
64
+ "62": "rock python, rock snake, Python sebae",
65
+ "63": "Indian cobra, Naja naja",
66
+ "64": "green mamba",
67
+ "65": "sea snake",
68
+ "66": "horned viper, cerastes, sand viper, horned asp, Cerastes cornutus",
69
+ "67": "diamondback, diamondback rattlesnake, Crotalus adamanteus",
70
+ "68": "sidewinder, horned rattlesnake, Crotalus cerastes",
71
+ "69": "trilobite",
72
+ "70": "harvestman, daddy longlegs, Phalangium opilio",
73
+ "71": "scorpion",
74
+ "72": "black and gold garden spider, Argiope aurantia",
75
+ "73": "barn spider, Araneus cavaticus",
76
+ "74": "garden spider, Aranea diademata",
77
+ "75": "black widow, Latrodectus mactans",
78
+ "76": "tarantula",
79
+ "77": "wolf spider, hunting spider",
80
+ "78": "tick",
81
+ "79": "centipede",
82
+ "80": "black grouse",
83
+ "81": "ptarmigan",
84
+ "82": "ruffed grouse, partridge, Bonasa umbellus",
85
+ "83": "prairie chicken, prairie grouse, prairie fowl",
86
+ "84": "peacock",
87
+ "85": "quail",
88
+ "86": "partridge",
89
+ "87": "African grey, African gray, Psittacus erithacus",
90
+ "88": "macaw",
91
+ "89": "sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita",
92
+ "90": "lorikeet",
93
+ "91": "coucal",
94
+ "92": "bee eater",
95
+ "93": "hornbill",
96
+ "94": "hummingbird",
97
+ "95": "jacamar",
98
+ "96": "toucan",
99
+ "97": "drake",
100
+ "98": "red-breasted merganser, Mergus serrator",
101
+ "99": "goose",
102
+ "100": "black swan, Cygnus atratus",
103
+ "101": "tusker",
104
+ "102": "echidna, spiny anteater, anteater",
105
+ "103": "platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus",
106
+ "104": "wallaby, brush kangaroo",
107
+ "105": "koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus",
108
+ "106": "wombat",
109
+ "107": "jellyfish",
110
+ "108": "sea anemone, anemone",
111
+ "109": "brain coral",
112
+ "110": "flatworm, platyhelminth",
113
+ "111": "nematode, nematode worm, roundworm",
114
+ "112": "conch",
115
+ "113": "snail",
116
+ "114": "slug",
117
+ "115": "sea slug, nudibranch",
118
+ "116": "chiton, coat-of-mail shell, sea cradle, polyplacophore",
119
+ "117": "chambered nautilus, pearly nautilus, nautilus",
120
+ "118": "Dungeness crab, Cancer magister",
121
+ "119": "rock crab, Cancer irroratus",
122
+ "120": "fiddler crab",
123
+ "121": "king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica",
124
+ "122": "American lobster, Northern lobster, Maine lobster, Homarus americanus",
125
+ "123": "spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish",
126
+ "124": "crayfish, crawfish, crawdad, crawdaddy",
127
+ "125": "hermit crab",
128
+ "126": "isopod",
129
+ "127": "white stork, Ciconia ciconia",
130
+ "128": "black stork, Ciconia nigra",
131
+ "129": "spoonbill",
132
+ "130": "flamingo",
133
+ "131": "little blue heron, Egretta caerulea",
134
+ "132": "American egret, great white heron, Egretta albus",
135
+ "133": "bittern",
136
+ "134": "crane",
137
+ "135": "limpkin, Aramus pictus",
138
+ "136": "European gallinule, Porphyrio porphyrio",
139
+ "137": "American coot, marsh hen, mud hen, water hen, Fulica americana",
140
+ "138": "bustard",
141
+ "139": "ruddy turnstone, Arenaria interpres",
142
+ "140": "red-backed sandpiper, dunlin, Erolia alpina",
143
+ "141": "redshank, Tringa totanus",
144
+ "142": "dowitcher",
145
+ "143": "oystercatcher, oyster catcher",
146
+ "144": "pelican",
147
+ "145": "king penguin, Aptenodytes patagonica",
148
+ "146": "albatross, mollymawk",
149
+ "147": "grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus",
150
+ "148": "killer whale, killer, orca, grampus, sea wolf, Orcinus orca",
151
+ "149": "dugong, Dugong dugon",
152
+ "150": "sea lion",
153
+ "151": "Chihuahua",
154
+ "152": "Japanese spaniel",
155
+ "153": "Maltese dog, Maltese terrier, Maltese",
156
+ "154": "Pekinese, Pekingese, Peke",
157
+ "155": "Shih-Tzu",
158
+ "156": "Blenheim spaniel",
159
+ "157": "papillon",
160
+ "158": "toy terrier",
161
+ "159": "Rhodesian ridgeback",
162
+ "160": "Afghan hound, Afghan",
163
+ "161": "basset, basset hound",
164
+ "162": "beagle",
165
+ "163": "bloodhound, sleuthhound",
166
+ "164": "bluetick",
167
+ "165": "black-and-tan coonhound",
168
+ "166": "Walker hound, Walker foxhound",
169
+ "167": "English foxhound",
170
+ "168": "redbone",
171
+ "169": "borzoi, Russian wolfhound",
172
+ "170": "Irish wolfhound",
173
+ "171": "Italian greyhound",
174
+ "172": "whippet",
175
+ "173": "Ibizan hound, Ibizan Podenco",
176
+ "174": "Norwegian elkhound, elkhound",
177
+ "175": "otterhound, otter hound",
178
+ "176": "Saluki, gazelle hound",
179
+ "177": "Scottish deerhound, deerhound",
180
+ "178": "Weimaraner",
181
+ "179": "Staffordshire bullterrier, Staffordshire bull terrier",
182
+ "180": "American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier",
183
+ "181": "Bedlington terrier",
184
+ "182": "Border terrier",
185
+ "183": "Kerry blue terrier",
186
+ "184": "Irish terrier",
187
+ "185": "Norfolk terrier",
188
+ "186": "Norwich terrier",
189
+ "187": "Yorkshire terrier",
190
+ "188": "wire-haired fox terrier",
191
+ "189": "Lakeland terrier",
192
+ "190": "Sealyham terrier, Sealyham",
193
+ "191": "Airedale, Airedale terrier",
194
+ "192": "cairn, cairn terrier",
195
+ "193": "Australian terrier",
196
+ "194": "Dandie Dinmont, Dandie Dinmont terrier",
197
+ "195": "Boston bull, Boston terrier",
198
+ "196": "miniature schnauzer",
199
+ "197": "giant schnauzer",
200
+ "198": "standard schnauzer",
201
+ "199": "Scotch terrier, Scottish terrier, Scottie",
202
+ "200": "Tibetan terrier, chrysanthemum dog",
203
+ "201": "silky terrier, Sydney silky",
204
+ "202": "soft-coated wheaten terrier",
205
+ "203": "West Highland white terrier",
206
+ "204": "Lhasa, Lhasa apso",
207
+ "205": "flat-coated retriever",
208
+ "206": "curly-coated retriever",
209
+ "207": "golden retriever",
210
+ "208": "Labrador retriever",
211
+ "209": "Chesapeake Bay retriever",
212
+ "210": "German short-haired pointer",
213
+ "211": "vizsla, Hungarian pointer",
214
+ "212": "English setter",
215
+ "213": "Irish setter, red setter",
216
+ "214": "Gordon setter",
217
+ "215": "Brittany spaniel",
218
+ "216": "clumber, clumber spaniel",
219
+ "217": "English springer, English springer spaniel",
220
+ "218": "Welsh springer spaniel",
221
+ "219": "cocker spaniel, English cocker spaniel, cocker",
222
+ "220": "Sussex spaniel",
223
+ "221": "Irish water spaniel",
224
+ "222": "kuvasz",
225
+ "223": "schipperke",
226
+ "224": "groenendael",
227
+ "225": "malinois",
228
+ "226": "briard",
229
+ "227": "kelpie",
230
+ "228": "komondor",
231
+ "229": "Old English sheepdog, bobtail",
232
+ "230": "Shetland sheepdog, Shetland sheep dog, Shetland",
233
+ "231": "collie",
234
+ "232": "Border collie",
235
+ "233": "Bouvier des Flandres, Bouviers des Flandres",
236
+ "234": "Rottweiler",
237
+ "235": "German shepherd, German shepherd dog, German police dog, alsatian",
238
+ "236": "Doberman, Doberman pinscher",
239
+ "237": "miniature pinscher",
240
+ "238": "Greater Swiss Mountain dog",
241
+ "239": "Bernese mountain dog",
242
+ "240": "Appenzeller",
243
+ "241": "EntleBucher",
244
+ "242": "boxer",
245
+ "243": "bull mastiff",
246
+ "244": "Tibetan mastiff",
247
+ "245": "French bulldog",
248
+ "246": "Great Dane",
249
+ "247": "Saint Bernard, St Bernard",
250
+ "248": "Eskimo dog, husky",
251
+ "249": "malamute, malemute, Alaskan malamute",
252
+ "250": "Siberian husky",
253
+ "251": "dalmatian, coach dog, carriage dog",
254
+ "252": "affenpinscher, monkey pinscher, monkey dog",
255
+ "253": "basenji",
256
+ "254": "pug, pug-dog",
257
+ "255": "Leonberg",
258
+ "256": "Newfoundland, Newfoundland dog",
259
+ "257": "Great Pyrenees",
260
+ "258": "Samoyed, Samoyede",
261
+ "259": "Pomeranian",
262
+ "260": "chow, chow chow",
263
+ "261": "keeshond",
264
+ "262": "Brabancon griffon",
265
+ "263": "Pembroke, Pembroke Welsh corgi",
266
+ "264": "Cardigan, Cardigan Welsh corgi",
267
+ "265": "toy poodle",
268
+ "266": "miniature poodle",
269
+ "267": "standard poodle",
270
+ "268": "Mexican hairless",
271
+ "269": "timber wolf, grey wolf, gray wolf, Canis lupus",
272
+ "270": "white wolf, Arctic wolf, Canis lupus tundrarum",
273
+ "271": "red wolf, maned wolf, Canis rufus, Canis niger",
274
+ "272": "coyote, prairie wolf, brush wolf, Canis latrans",
275
+ "273": "dingo, warrigal, warragal, Canis dingo",
276
+ "274": "dhole, Cuon alpinus",
277
+ "275": "African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus",
278
+ "276": "hyena, hyaena",
279
+ "277": "red fox, Vulpes vulpes",
280
+ "278": "kit fox, Vulpes macrotis",
281
+ "279": "Arctic fox, white fox, Alopex lagopus",
282
+ "280": "grey fox, gray fox, Urocyon cinereoargenteus",
283
+ "281": "tabby, tabby cat",
284
+ "282": "tiger cat",
285
+ "283": "Persian cat",
286
+ "284": "Siamese cat, Siamese",
287
+ "285": "Egyptian cat",
288
+ "286": "cougar, puma, catamount, mountain lion, painter, panther, Felis concolor",
289
+ "287": "lynx, catamount",
290
+ "288": "leopard, Panthera pardus",
291
+ "289": "snow leopard, ounce, Panthera uncia",
292
+ "290": "jaguar, panther, Panthera onca, Felis onca",
293
+ "291": "lion, king of beasts, Panthera leo",
294
+ "292": "tiger, Panthera tigris",
295
+ "293": "cheetah, chetah, Acinonyx jubatus",
296
+ "294": "brown bear, bruin, Ursus arctos",
297
+ "295": "American black bear, black bear, Ursus americanus, Euarctos americanus",
298
+ "296": "ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus",
299
+ "297": "sloth bear, Melursus ursinus, Ursus ursinus",
300
+ "298": "mongoose",
301
+ "299": "meerkat, mierkat",
302
+ "300": "tiger beetle",
303
+ "301": "ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle",
304
+ "302": "ground beetle, carabid beetle",
305
+ "303": "long-horned beetle, longicorn, longicorn beetle",
306
+ "304": "leaf beetle, chrysomelid",
307
+ "305": "dung beetle",
308
+ "306": "rhinoceros beetle",
309
+ "307": "weevil",
310
+ "308": "fly",
311
+ "309": "bee",
312
+ "310": "ant, emmet, pismire",
313
+ "311": "grasshopper, hopper",
314
+ "312": "cricket",
315
+ "313": "walking stick, walkingstick, stick insect",
316
+ "314": "cockroach, roach",
317
+ "315": "mantis, mantid",
318
+ "316": "cicada, cicala",
319
+ "317": "leafhopper",
320
+ "318": "lacewing, lacewing fly",
321
+ "319": "dragonfly, darning needle, devils darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk",
322
+ "320": "damselfly",
323
+ "321": "admiral",
324
+ "322": "ringlet, ringlet butterfly",
325
+ "323": "monarch, monarch butterfly, milkweed butterfly, Danaus plexippus",
326
+ "324": "cabbage butterfly",
327
+ "325": "sulphur butterfly, sulfur butterfly",
328
+ "326": "lycaenid, lycaenid butterfly",
329
+ "327": "starfish, sea star",
330
+ "328": "sea urchin",
331
+ "329": "sea cucumber, holothurian",
332
+ "330": "wood rabbit, cottontail, cottontail rabbit",
333
+ "331": "hare",
334
+ "332": "Angora, Angora rabbit",
335
+ "333": "hamster",
336
+ "334": "porcupine, hedgehog",
337
+ "335": "fox squirrel, eastern fox squirrel, Sciurus niger",
338
+ "336": "marmot",
339
+ "337": "beaver",
340
+ "338": "guinea pig, Cavia cobaya",
341
+ "339": "sorrel",
342
+ "340": "zebra",
343
+ "341": "hog, pig, grunter, squealer, Sus scrofa",
344
+ "342": "wild boar, boar, Sus scrofa",
345
+ "343": "warthog",
346
+ "344": "hippopotamus, hippo, river horse, Hippopotamus amphibius",
347
+ "345": "ox",
348
+ "346": "water buffalo, water ox, Asiatic buffalo, Bubalus bubalis",
349
+ "347": "bison",
350
+ "348": "ram, tup",
351
+ "349": "bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis",
352
+ "350": "ibex, Capra ibex",
353
+ "351": "hartebeest",
354
+ "352": "impala, Aepyceros melampus",
355
+ "353": "gazelle",
356
+ "354": "Arabian camel, dromedary, Camelus dromedarius",
357
+ "355": "llama",
358
+ "356": "weasel",
359
+ "357": "mink",
360
+ "358": "polecat, fitch, foulmart, foumart, Mustela putorius",
361
+ "359": "black-footed ferret, ferret, Mustela nigripes",
362
+ "360": "otter",
363
+ "361": "skunk, polecat, wood pussy",
364
+ "362": "badger",
365
+ "363": "armadillo",
366
+ "364": "three-toed sloth, ai, Bradypus tridactylus",
367
+ "365": "orangutan, orang, orangutang, Pongo pygmaeus",
368
+ "366": "gorilla, Gorilla gorilla",
369
+ "367": "chimpanzee, chimp, Pan troglodytes",
370
+ "368": "gibbon, Hylobates lar",
371
+ "369": "siamang, Hylobates syndactylus, Symphalangus syndactylus",
372
+ "370": "guenon, guenon monkey",
373
+ "371": "patas, hussar monkey, Erythrocebus patas",
374
+ "372": "baboon",
375
+ "373": "macaque",
376
+ "374": "langur",
377
+ "375": "colobus, colobus monkey",
378
+ "376": "proboscis monkey, Nasalis larvatus",
379
+ "377": "marmoset",
380
+ "378": "capuchin, ringtail, Cebus capucinus",
381
+ "379": "howler monkey, howler",
382
+ "380": "titi, titi monkey",
383
+ "381": "spider monkey, Ateles geoffroyi",
384
+ "382": "squirrel monkey, Saimiri sciureus",
385
+ "383": "Madagascar cat, ring-tailed lemur, Lemur catta",
386
+ "384": "indri, indris, Indri indri, Indri brevicaudatus",
387
+ "385": "Indian elephant, Elephas maximus",
388
+ "386": "African elephant, Loxodonta africana",
389
+ "387": "lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens",
390
+ "388": "giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca",
391
+ "389": "barracouta, snoek",
392
+ "390": "eel",
393
+ "391": "coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch",
394
+ "392": "rock beauty, Holocanthus tricolor",
395
+ "393": "anemone fish",
396
+ "394": "sturgeon",
397
+ "395": "gar, garfish, garpike, billfish, Lepisosteus osseus",
398
+ "396": "lionfish",
399
+ "397": "puffer, pufferfish, blowfish, globefish",
400
+ "398": "abacus",
401
+ "399": "abaya",
402
+ "400": "academic gown, academic robe, judge robe",
403
+ "401": "accordion, piano accordion, squeeze box",
404
+ "402": "acoustic guitar",
405
+ "403": "aircraft carrier, carrier, flattop, attack aircraft carrier",
406
+ "404": "airliner",
407
+ "405": "airship, dirigible",
408
+ "406": "altar",
409
+ "407": "ambulance",
410
+ "408": "amphibian, amphibious vehicle",
411
+ "409": "analog clock",
412
+ "410": "apiary, bee house",
413
+ "411": "apron",
414
+ "412": "ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin",
415
+ "413": "assault rifle, assault gun",
416
+ "414": "backpack, back pack, knapsack, packsack, rucksack, haversack",
417
+ "415": "bakery, bakeshop, bakehouse",
418
+ "416": "balance beam, beam",
419
+ "417": "balloon",
420
+ "418": "ballpoint, ballpoint pen, ballpen, Biro",
421
+ "419": "Band Aid",
422
+ "420": "banjo",
423
+ "421": "bannister, banister, balustrade, balusters, handrail",
424
+ "422": "barbell",
425
+ "423": "barber chair",
426
+ "424": "barbershop",
427
+ "425": "barn",
428
+ "426": "barometer",
429
+ "427": "barrel, cask",
430
+ "428": "barrow, garden cart, lawn cart, wheelbarrow",
431
+ "429": "baseball",
432
+ "430": "basketball",
433
+ "431": "bassinet",
434
+ "432": "bassoon",
435
+ "433": "bathing cap, swimming cap",
436
+ "434": "bath towel",
437
+ "435": "bathtub, bathing tub, bath, tub",
438
+ "436": "beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon",
439
+ "437": "beacon, lighthouse, beacon light, pharos",
440
+ "438": "beaker",
441
+ "439": "bearskin, busby, shako",
442
+ "440": "beer bottle",
443
+ "441": "beer glass",
444
+ "442": "bell cote, bell cot",
445
+ "443": "bib",
446
+ "444": "bicycle-built-for-two, tandem bicycle, tandem",
447
+ "445": "bikini, two-piece",
448
+ "446": "binder, ring-binder",
449
+ "447": "binoculars, field glasses, opera glasses",
450
+ "448": "birdhouse",
451
+ "449": "boathouse",
452
+ "450": "bobsled, bobsleigh, bob",
453
+ "451": "bolo tie, bolo, bola tie, bola",
454
+ "452": "bonnet, poke bonnet",
455
+ "453": "bookcase",
456
+ "454": "bookshop, bookstore, bookstall",
457
+ "455": "bottlecap",
458
+ "456": "bow",
459
+ "457": "bow tie, bow-tie, bowtie",
460
+ "458": "brass, memorial tablet, plaque",
461
+ "459": "brassiere, bra, bandeau",
462
+ "460": "breakwater, groin, groyne, mole, bulwark, seawall, jetty",
463
+ "461": "breastplate, aegis, egis",
464
+ "462": "broom",
465
+ "463": "bucket, pail",
466
+ "464": "buckle",
467
+ "465": "bulletproof vest",
468
+ "466": "bullet train, bullet",
469
+ "467": "butcher shop, meat market",
470
+ "468": "cab, hack, taxi, taxicab",
471
+ "469": "caldron, cauldron",
472
+ "470": "candle, taper, wax light",
473
+ "471": "cannon",
474
+ "472": "canoe",
475
+ "473": "can opener, tin opener",
476
+ "474": "cardigan",
477
+ "475": "car mirror",
478
+ "476": "carousel, carrousel, merry-go-round, roundabout, whirligig",
479
+ "477": "carpenters kit, tool kit",
480
+ "478": "carton",
481
+ "479": "car wheel",
482
+ "480": "cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM",
483
+ "481": "cassette",
484
+ "482": "cassette player",
485
+ "483": "castle",
486
+ "484": "catamaran",
487
+ "485": "CD player",
488
+ "486": "cello, violoncello",
489
+ "487": "cellular telephone, cellular phone, cellphone, cell, mobile phone",
490
+ "488": "chain",
491
+ "489": "chainlink fence",
492
+ "490": "chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour",
493
+ "491": "chain saw, chainsaw",
494
+ "492": "chest",
495
+ "493": "chiffonier, commode",
496
+ "494": "chime, bell, gong",
497
+ "495": "china cabinet, china closet",
498
+ "496": "Christmas stocking",
499
+ "497": "church, church building",
500
+ "498": "cinema, movie theater, movie theatre, movie house, picture palace",
501
+ "499": "cleaver, meat cleaver, chopper",
502
+ "500": "cliff dwelling",
503
+ "501": "cloak",
504
+ "502": "clog, geta, patten, sabot",
505
+ "503": "cocktail shaker",
506
+ "504": "coffee mug",
507
+ "505": "coffeepot",
508
+ "506": "coil, spiral, volute, whorl, helix",
509
+ "507": "combination lock",
510
+ "508": "computer keyboard, keypad",
511
+ "509": "confectionery, confectionary, candy store",
512
+ "510": "container ship, containership, container vessel",
513
+ "511": "convertible",
514
+ "512": "corkscrew, bottle screw",
515
+ "513": "cornet, horn, trumpet, trump",
516
+ "514": "cowboy boot",
517
+ "515": "cowboy hat, ten-gallon hat",
518
+ "516": "cradle",
519
+ "517": "crane",
520
+ "518": "crash helmet",
521
+ "519": "crate",
522
+ "520": "crib, cot",
523
+ "521": "Crock Pot",
524
+ "522": "croquet ball",
525
+ "523": "crutch",
526
+ "524": "cuirass",
527
+ "525": "dam, dike, dyke",
528
+ "526": "desk",
529
+ "527": "desktop computer",
530
+ "528": "dial telephone, dial phone",
531
+ "529": "diaper, nappy, napkin",
532
+ "530": "digital clock",
533
+ "531": "digital watch",
534
+ "532": "dining table, board",
535
+ "533": "dishrag, dishcloth",
536
+ "534": "dishwasher, dish washer, dishwashing machine",
537
+ "535": "disk brake, disc brake",
538
+ "536": "dock, dockage, docking facility",
539
+ "537": "dogsled, dog sled, dog sleigh",
540
+ "538": "dome",
541
+ "539": "doormat, welcome mat",
542
+ "540": "drilling platform, offshore rig",
543
+ "541": "drum, membranophone, tympan",
544
+ "542": "drumstick",
545
+ "543": "dumbbell",
546
+ "544": "Dutch oven",
547
+ "545": "electric fan, blower",
548
+ "546": "electric guitar",
549
+ "547": "electric locomotive",
550
+ "548": "entertainment center",
551
+ "549": "envelope",
552
+ "550": "espresso maker",
553
+ "551": "face powder",
554
+ "552": "feather boa, boa",
555
+ "553": "file, file cabinet, filing cabinet",
556
+ "554": "fireboat",
557
+ "555": "fire engine, fire truck",
558
+ "556": "fire screen, fireguard",
559
+ "557": "flagpole, flagstaff",
560
+ "558": "flute, transverse flute",
561
+ "559": "folding chair",
562
+ "560": "football helmet",
563
+ "561": "forklift",
564
+ "562": "fountain",
565
+ "563": "fountain pen",
566
+ "564": "four-poster",
567
+ "565": "freight car",
568
+ "566": "French horn, horn",
569
+ "567": "frying pan, frypan, skillet",
570
+ "568": "fur coat",
571
+ "569": "garbage truck, dustcart",
572
+ "570": "gasmask, respirator, gas helmet",
573
+ "571": "gas pump, gasoline pump, petrol pump, island dispenser",
574
+ "572": "goblet",
575
+ "573": "go-kart",
576
+ "574": "golf ball",
577
+ "575": "golfcart, golf cart",
578
+ "576": "gondola",
579
+ "577": "gong, tam-tam",
580
+ "578": "gown",
581
+ "579": "grand piano, grand",
582
+ "580": "greenhouse, nursery, glasshouse",
583
+ "581": "grille, radiator grille",
584
+ "582": "grocery store, grocery, food market, market",
585
+ "583": "guillotine",
586
+ "584": "hair slide",
587
+ "585": "hair spray",
588
+ "586": "half track",
589
+ "587": "hammer",
590
+ "588": "hamper",
591
+ "589": "hand blower, blow dryer, blow drier, hair dryer, hair drier",
592
+ "590": "hand-held computer, hand-held microcomputer",
593
+ "591": "handkerchief, hankie, hanky, hankey",
594
+ "592": "hard disc, hard disk, fixed disk",
595
+ "593": "harmonica, mouth organ, harp, mouth harp",
596
+ "594": "harp",
597
+ "595": "harvester, reaper",
598
+ "596": "hatchet",
599
+ "597": "holster",
600
+ "598": "home theater, home theatre",
601
+ "599": "honeycomb",
602
+ "600": "hook, claw",
603
+ "601": "hoopskirt, crinoline",
604
+ "602": "horizontal bar, high bar",
605
+ "603": "horse cart, horse-cart",
606
+ "604": "hourglass",
607
+ "605": "iPod",
608
+ "606": "iron, smoothing iron",
609
+ "607": "jack-o-lantern",
610
+ "608": "jean, blue jean, denim",
611
+ "609": "jeep, landrover",
612
+ "610": "jersey, T-shirt, tee shirt",
613
+ "611": "jigsaw puzzle",
614
+ "612": "jinrikisha, ricksha, rickshaw",
615
+ "613": "joystick",
616
+ "614": "kimono",
617
+ "615": "knee pad",
618
+ "616": "knot",
619
+ "617": "lab coat, laboratory coat",
620
+ "618": "ladle",
621
+ "619": "lampshade, lamp shade",
622
+ "620": "laptop, laptop computer",
623
+ "621": "lawn mower, mower",
624
+ "622": "lens cap, lens cover",
625
+ "623": "letter opener, paper knife, paperknife",
626
+ "624": "library",
627
+ "625": "lifeboat",
628
+ "626": "lighter, light, igniter, ignitor",
629
+ "627": "limousine, limo",
630
+ "628": "liner, ocean liner",
631
+ "629": "lipstick, lip rouge",
632
+ "630": "Loafer",
633
+ "631": "lotion",
634
+ "632": "loudspeaker, speaker, speaker unit, loudspeaker system, speaker system",
635
+ "633": "loupe, jewelers loupe",
636
+ "634": "lumbermill, sawmill",
637
+ "635": "magnetic compass",
638
+ "636": "mailbag, postbag",
639
+ "637": "mailbox, letter box",
640
+ "638": "maillot",
641
+ "639": "maillot, tank suit",
642
+ "640": "manhole cover",
643
+ "641": "maraca",
644
+ "642": "marimba, xylophone",
645
+ "643": "mask",
646
+ "644": "matchstick",
647
+ "645": "maypole",
648
+ "646": "maze, labyrinth",
649
+ "647": "measuring cup",
650
+ "648": "medicine chest, medicine cabinet",
651
+ "649": "megalith, megalithic structure",
652
+ "650": "microphone, mike",
653
+ "651": "microwave, microwave oven",
654
+ "652": "military uniform",
655
+ "653": "milk can",
656
+ "654": "minibus",
657
+ "655": "miniskirt, mini",
658
+ "656": "minivan",
659
+ "657": "missile",
660
+ "658": "mitten",
661
+ "659": "mixing bowl",
662
+ "660": "mobile home, manufactured home",
663
+ "661": "Model T",
664
+ "662": "modem",
665
+ "663": "monastery",
666
+ "664": "monitor",
667
+ "665": "moped",
668
+ "666": "mortar",
669
+ "667": "mortarboard",
670
+ "668": "mosque",
671
+ "669": "mosquito net",
672
+ "670": "motor scooter, scooter",
673
+ "671": "mountain bike, all-terrain bike, off-roader",
674
+ "672": "mountain tent",
675
+ "673": "mouse, computer mouse",
676
+ "674": "mousetrap",
677
+ "675": "moving van",
678
+ "676": "muzzle",
679
+ "677": "nail",
680
+ "678": "neck brace",
681
+ "679": "necklace",
682
+ "680": "nipple",
683
+ "681": "notebook, notebook computer",
684
+ "682": "obelisk",
685
+ "683": "oboe, hautboy, hautbois",
686
+ "684": "ocarina, sweet potato",
687
+ "685": "odometer, hodometer, mileometer, milometer",
688
+ "686": "oil filter",
689
+ "687": "organ, pipe organ",
690
+ "688": "oscilloscope, scope, cathode-ray oscilloscope, CRO",
691
+ "689": "overskirt",
692
+ "690": "oxcart",
693
+ "691": "oxygen mask",
694
+ "692": "packet",
695
+ "693": "paddle, boat paddle",
696
+ "694": "paddlewheel, paddle wheel",
697
+ "695": "padlock",
698
+ "696": "paintbrush",
699
+ "697": "pajama, pyjama, pjs, jammies",
700
+ "698": "palace",
701
+ "699": "panpipe, pandean pipe, syrinx",
702
+ "700": "paper towel",
703
+ "701": "parachute, chute",
704
+ "702": "parallel bars, bars",
705
+ "703": "park bench",
706
+ "704": "parking meter",
707
+ "705": "passenger car, coach, carriage",
708
+ "706": "patio, terrace",
709
+ "707": "pay-phone, pay-station",
710
+ "708": "pedestal, plinth, footstall",
711
+ "709": "pencil box, pencil case",
712
+ "710": "pencil sharpener",
713
+ "711": "perfume, essence",
714
+ "712": "Petri dish",
715
+ "713": "photocopier",
716
+ "714": "pick, plectrum, plectron",
717
+ "715": "pickelhaube",
718
+ "716": "picket fence, paling",
719
+ "717": "pickup, pickup truck",
720
+ "718": "pier",
721
+ "719": "piggy bank, penny bank",
722
+ "720": "pill bottle",
723
+ "721": "pillow",
724
+ "722": "ping-pong ball",
725
+ "723": "pinwheel",
726
+ "724": "pirate, pirate ship",
727
+ "725": "pitcher, ewer",
728
+ "726": "plane, carpenters plane, woodworking plane",
729
+ "727": "planetarium",
730
+ "728": "plastic bag",
731
+ "729": "plate rack",
732
+ "730": "plow, plough",
733
+ "731": "plunger, plumbers helper",
734
+ "732": "Polaroid camera, Polaroid Land camera",
735
+ "733": "pole",
736
+ "734": "police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria",
737
+ "735": "poncho",
738
+ "736": "pool table, billiard table, snooker table",
739
+ "737": "pop bottle, soda bottle",
740
+ "738": "pot, flowerpot",
741
+ "739": "potters wheel",
742
+ "740": "power drill",
743
+ "741": "prayer rug, prayer mat",
744
+ "742": "printer",
745
+ "743": "prison, prison house",
746
+ "744": "projectile, missile",
747
+ "745": "projector",
748
+ "746": "puck, hockey puck",
749
+ "747": "punching bag, punch bag, punching ball, punchball",
750
+ "748": "purse",
751
+ "749": "quill, quill pen",
752
+ "750": "quilt, comforter, comfort, puff",
753
+ "751": "racer, race car, racing car",
754
+ "752": "racket, racquet",
755
+ "753": "radiator",
756
+ "754": "radio, wireless",
757
+ "755": "radio telescope, radio reflector",
758
+ "756": "rain barrel",
759
+ "757": "recreational vehicle, RV, R.V.",
760
+ "758": "reel",
761
+ "759": "reflex camera",
762
+ "760": "refrigerator, icebox",
763
+ "761": "remote control, remote",
764
+ "762": "restaurant, eating house, eating place, eatery",
765
+ "763": "revolver, six-gun, six-shooter",
766
+ "764": "rifle",
767
+ "765": "rocking chair, rocker",
768
+ "766": "rotisserie",
769
+ "767": "rubber eraser, rubber, pencil eraser",
770
+ "768": "rugby ball",
771
+ "769": "rule, ruler",
772
+ "770": "running shoe",
773
+ "771": "safe",
774
+ "772": "safety pin",
775
+ "773": "saltshaker, salt shaker",
776
+ "774": "sandal",
777
+ "775": "sarong",
778
+ "776": "sax, saxophone",
779
+ "777": "scabbard",
780
+ "778": "scale, weighing machine",
781
+ "779": "school bus",
782
+ "780": "schooner",
783
+ "781": "scoreboard",
784
+ "782": "screen, CRT screen",
785
+ "783": "screw",
786
+ "784": "screwdriver",
787
+ "785": "seat belt, seatbelt",
788
+ "786": "sewing machine",
789
+ "787": "shield, buckler",
790
+ "788": "shoe shop, shoe-shop, shoe store",
791
+ "789": "shoji",
792
+ "790": "shopping basket",
793
+ "791": "shopping cart",
794
+ "792": "shovel",
795
+ "793": "shower cap",
796
+ "794": "shower curtain",
797
+ "795": "ski",
798
+ "796": "ski mask",
799
+ "797": "sleeping bag",
800
+ "798": "slide rule, slipstick",
801
+ "799": "sliding door",
802
+ "800": "slot, one-armed bandit",
803
+ "801": "snorkel",
804
+ "802": "snowmobile",
805
+ "803": "snowplow, snowplough",
806
+ "804": "soap dispenser",
807
+ "805": "soccer ball",
808
+ "806": "sock",
809
+ "807": "solar dish, solar collector, solar furnace",
810
+ "808": "sombrero",
811
+ "809": "soup bowl",
812
+ "810": "space bar",
813
+ "811": "space heater",
814
+ "812": "space shuttle",
815
+ "813": "spatula",
816
+ "814": "speedboat",
817
+ "815": "spider web, spiders web",
818
+ "816": "spindle",
819
+ "817": "sports car, sport car",
820
+ "818": "spotlight, spot",
821
+ "819": "stage",
822
+ "820": "steam locomotive",
823
+ "821": "steel arch bridge",
824
+ "822": "steel drum",
825
+ "823": "stethoscope",
826
+ "824": "stole",
827
+ "825": "stone wall",
828
+ "826": "stopwatch, stop watch",
829
+ "827": "stove",
830
+ "828": "strainer",
831
+ "829": "streetcar, tram, tramcar, trolley, trolley car",
832
+ "830": "stretcher",
833
+ "831": "studio couch, day bed",
834
+ "832": "stupa, tope",
835
+ "833": "submarine, pigboat, sub, U-boat",
836
+ "834": "suit, suit of clothes",
837
+ "835": "sundial",
838
+ "836": "sunglass",
839
+ "837": "sunglasses, dark glasses, shades",
840
+ "838": "sunscreen, sunblock, sun blocker",
841
+ "839": "suspension bridge",
842
+ "840": "swab, swob, mop",
843
+ "841": "sweatshirt",
844
+ "842": "swimming trunks, bathing trunks",
845
+ "843": "swing",
846
+ "844": "switch, electric switch, electrical switch",
847
+ "845": "syringe",
848
+ "846": "table lamp",
849
+ "847": "tank, army tank, armored combat vehicle, armoured combat vehicle",
850
+ "848": "tape player",
851
+ "849": "teapot",
852
+ "850": "teddy, teddy bear",
853
+ "851": "television, television system",
854
+ "852": "tennis ball",
855
+ "853": "thatch, thatched roof",
856
+ "854": "theater curtain, theatre curtain",
857
+ "855": "thimble",
858
+ "856": "thresher, thrasher, threshing machine",
859
+ "857": "throne",
860
+ "858": "tile roof",
861
+ "859": "toaster",
862
+ "860": "tobacco shop, tobacconist shop, tobacconist",
863
+ "861": "toilet seat",
864
+ "862": "torch",
865
+ "863": "totem pole",
866
+ "864": "tow truck, tow car, wrecker",
867
+ "865": "toyshop",
868
+ "866": "tractor",
869
+ "867": "trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi",
870
+ "868": "tray",
871
+ "869": "trench coat",
872
+ "870": "tricycle, trike, velocipede",
873
+ "871": "trimaran",
874
+ "872": "tripod",
875
+ "873": "triumphal arch",
876
+ "874": "trolleybus, trolley coach, trackless trolley",
877
+ "875": "trombone",
878
+ "876": "tub, vat",
879
+ "877": "turnstile",
880
+ "878": "typewriter keyboard",
881
+ "879": "umbrella",
882
+ "880": "unicycle, monocycle",
883
+ "881": "upright, upright piano",
884
+ "882": "vacuum, vacuum cleaner",
885
+ "883": "vase",
886
+ "884": "vault",
887
+ "885": "velvet",
888
+ "886": "vending machine",
889
+ "887": "vestment",
890
+ "888": "viaduct",
891
+ "889": "violin, fiddle",
892
+ "890": "volleyball",
893
+ "891": "waffle iron",
894
+ "892": "wall clock",
895
+ "893": "wallet, billfold, notecase, pocketbook",
896
+ "894": "wardrobe, closet, press",
897
+ "895": "warplane, military plane",
898
+ "896": "washbasin, handbasin, washbowl, lavabo, wash-hand basin",
899
+ "897": "washer, automatic washer, washing machine",
900
+ "898": "water bottle",
901
+ "899": "water jug",
902
+ "900": "water tower",
903
+ "901": "whiskey jug",
904
+ "902": "whistle",
905
+ "903": "wig",
906
+ "904": "window screen",
907
+ "905": "window shade",
908
+ "906": "Windsor tie",
909
+ "907": "wine bottle",
910
+ "908": "wing",
911
+ "909": "wok",
912
+ "910": "wooden spoon",
913
+ "911": "wool, woolen, woollen",
914
+ "912": "worm fence, snake fence, snake-rail fence, Virginia fence",
915
+ "913": "wreck",
916
+ "914": "yawl",
917
+ "915": "yurt",
918
+ "916": "web site, website, internet site, site",
919
+ "917": "comic book",
920
+ "918": "crossword puzzle, crossword",
921
+ "919": "street sign",
922
+ "920": "traffic light, traffic signal, stoplight",
923
+ "921": "book jacket, dust cover, dust jacket, dust wrapper",
924
+ "922": "menu",
925
+ "923": "plate",
926
+ "924": "guacamole",
927
+ "925": "consomme",
928
+ "926": "hot pot, hotpot",
929
+ "927": "trifle",
930
+ "928": "ice cream, icecream",
931
+ "929": "ice lolly, lolly, lollipop, popsicle",
932
+ "930": "French loaf",
933
+ "931": "bagel, beigel",
934
+ "932": "pretzel",
935
+ "933": "cheeseburger",
936
+ "934": "hotdog, hot dog, red hot",
937
+ "935": "mashed potato",
938
+ "936": "head cabbage",
939
+ "937": "broccoli",
940
+ "938": "cauliflower",
941
+ "939": "zucchini, courgette",
942
+ "940": "spaghetti squash",
943
+ "941": "acorn squash",
944
+ "942": "butternut squash",
945
+ "943": "cucumber, cuke",
946
+ "944": "artichoke, globe artichoke",
947
+ "945": "bell pepper",
948
+ "946": "cardoon",
949
+ "947": "mushroom",
950
+ "948": "Granny Smith",
951
+ "949": "strawberry",
952
+ "950": "orange",
953
+ "951": "lemon",
954
+ "952": "fig",
955
+ "953": "pineapple, ananas",
956
+ "954": "banana",
957
+ "955": "jackfruit, jak, jack",
958
+ "956": "custard apple",
959
+ "957": "pomegranate",
960
+ "958": "hay",
961
+ "959": "carbonara",
962
+ "960": "chocolate sauce, chocolate syrup",
963
+ "961": "dough",
964
+ "962": "meat loaf, meatloaf",
965
+ "963": "pizza, pizza pie",
966
+ "964": "potpie",
967
+ "965": "burrito",
968
+ "966": "red wine",
969
+ "967": "espresso",
970
+ "968": "cup",
971
+ "969": "eggnog",
972
+ "970": "alp",
973
+ "971": "bubble",
974
+ "972": "cliff, drop, drop-off",
975
+ "973": "coral reef",
976
+ "974": "geyser",
977
+ "975": "lakeside, lakeshore",
978
+ "976": "promontory, headland, head, foreland",
979
+ "977": "sandbar, sand bar",
980
+ "978": "seashore, coast, seacoast, sea-coast",
981
+ "979": "valley, vale",
982
+ "980": "volcano",
983
+ "981": "ballplayer, baseball player",
984
+ "982": "groom, bridegroom",
985
+ "983": "scuba diver",
986
+ "984": "rapeseed",
987
+ "985": "daisy",
988
+ "986": "yellow ladys slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum",
989
+ "987": "corn",
990
+ "988": "acorn",
991
+ "989": "hip, rose hip, rosehip",
992
+ "990": "buckeye, horse chestnut, conker",
993
+ "991": "coral fungus",
994
+ "992": "agaric",
995
+ "993": "gyromitra",
996
+ "994": "stinkhorn, carrion fungus",
997
+ "995": "earthstar",
998
+ "996": "hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa",
999
+ "997": "bolete",
1000
+ "998": "ear, spike, capitulum",
1001
+ "999": "toilet tissue, toilet paper, bathroom tissue"
1002
+ }
labels/imagenet_labels.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ImageNet-1k class labels for ADM class-conditional generation.
2
+
3
+ Labels are stored as Hugging Face-style ``id2label`` JSON maps (string keys ``"0"``–``"999"``).
4
+ Each value is a comma-separated list of synonyms for that class id.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import json
10
+ from pathlib import Path
11
+ from typing import Literal
12
+
13
+ Language = Literal["en", "cn"]
14
+
15
+ _LABELS_DIR = Path(__file__).resolve().parent
16
+
17
+
18
+ def load_id2label(
19
+ labels_dir: Path | str | None = None,
20
+ lang: Language = "en",
21
+ ) -> dict[int, str]:
22
+ """Load ``id2label`` from ``id2label_en.json`` or ``id2label_cn.json``."""
23
+ root = Path(labels_dir) if labels_dir is not None else _LABELS_DIR
24
+ filename = "id2label_en.json" if lang == "en" else "id2label_cn.json"
25
+ path = root / filename
26
+ if not path.exists():
27
+ raise FileNotFoundError(f"ImageNet label file not found: {path}")
28
+
29
+ raw = json.loads(path.read_text(encoding="utf-8"))
30
+ return {int(key): value for key, value in raw.items()}
31
+
32
+
33
+ def build_label2id(id2label: dict[int, str]) -> dict[str, int]:
34
+ """Build a synonym -> class id map from an ``id2label`` dict (DiT-style)."""
35
+ labels: dict[str, int] = {}
36
+ for class_id, value in id2label.items():
37
+ for synonym in value.split(","):
38
+ synonym = synonym.strip()
39
+ if synonym:
40
+ labels[synonym] = int(class_id)
41
+ return dict(sorted(labels.items()))
42
+
43
+
44
+ def resolve_label_ids(
45
+ labels: str | list[str],
46
+ label2id: dict[str, int],
47
+ *,
48
+ lang: Language = "en",
49
+ ) -> list[int]:
50
+ """Map one or more label strings to ImageNet class ids."""
51
+ if isinstance(labels, str):
52
+ labels = [labels]
53
+
54
+ missing = [label for label in labels if label not in label2id]
55
+ if missing:
56
+ preview = ", ".join(list(label2id.keys())[:8])
57
+ raise ValueError(
58
+ f"Unknown label(s) for lang={lang!r}: {missing}. "
59
+ f"Example valid labels: {preview}, ..."
60
+ )
61
+ return [label2id[label] for label in labels]