data-archetype committed on
Commit 6199cd7 · verified · 1 Parent(s): 043616b

Upload folder using huggingface_hub

Files changed (1)
  1. capacitor_diffae/model.py +14 -5
capacitor_diffae/model.py CHANGED
@@ -206,19 +206,23 @@ class CapacitorDiffAE(nn.Module):
         return z * std.to(device=z.device) + mean.to(device=z.device)
 
     def encode(self, images: Tensor) -> Tensor:
-        """Encode images to latents (posterior mode).
+        """Encode images to whitened latents (posterior mode).
+
+        Returns latents whitened using per-channel running stats, ready for
+        use by downstream latent-space diffusion models.
 
         Args:
             images: [B, 3, H, W] in [-1, 1], H and W divisible by patch_size.
 
         Returns:
-            Latents [B, bottleneck_dim, H/patch, W/patch].
+            Whitened latents [B, bottleneck_dim, H/patch, W/patch].
         """
         try:
             model_dtype = next(self.parameters()).dtype
         except StopIteration:
             model_dtype = torch.float32
-        return self.encoder(images.to(dtype=model_dtype))
+        z = self.encoder(images.to(dtype=model_dtype))
+        return self.whiten(z).to(dtype=model_dtype)
 
     def encode_posterior(self, images: Tensor) -> EncoderPosterior:
         """Encode images and return the full posterior (mean + logsnr).
@@ -244,10 +248,12 @@ class CapacitorDiffAE(nn.Module):
         *,
         inference_config: CapacitorDiffAEInferenceConfig | None = None,
     ) -> Tensor:
-        """Decode latents to images via VP diffusion.
+        """Decode whitened latents to images via VP diffusion.
+
+        Latents are dewhitened internally before being passed to the decoder.
 
         Args:
-            latents: [B, bottleneck_dim, h, w] encoder latents.
+            latents: [B, bottleneck_dim, h, w] whitened encoder latents.
             height: Output image height (divisible by patch_size).
             width: Output image width (divisible by patch_size).
             inference_config: Optional inference parameters.
@@ -265,6 +271,9 @@ class CapacitorDiffAE(nn.Module):
         except StopIteration:
             model_dtype = torch.float32
 
+        # Dewhiten back to raw encoder scale for the decoder
+        latents = self.dewhiten(latents).to(dtype=model_dtype)
+
         if height % config.patch_size != 0 or width % config.patch_size != 0:
             raise ValueError(
                 f"height={height} and width={width} must be divisible by "