FlashSR-MLX-4bit

MLX port of FlashSR, a single-step distilled latent-diffusion model for audio super-resolution, with weights quantized to INT4 for on-device inference on Apple Silicon. It upsamples mono audio at any input sample rate to 48 kHz in one diffusion pass, using a specialised BigVGAN-flavoured SR vocoder to produce the waveform. FlashSR is the distilled student of AudioSR.

Model

| Property | Value |
|---|---|
| Total parameters | 638 M (VAE 223 M + UNet 258 M + vocoder 157 M) |
| Diffusion steps | 1 (distilled v-prediction student, cosine schedule, t = 999) |
| Quantization | INT4 weight-only, group size 64, mode mlx_affine_flat |
| Format | MLX safetensors (single combined bundle) |
| Sample rate | 48 kHz mono out (mono in at any sample rate) |
| Frame length | 5.12 s (245 760 samples) per forward pass |
| Bundle size | 346 MB on disk |
| Source | jakeoneijk/FlashSR_Inference |
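The "Diffusion steps" row packs several ideas into one line. The following minimal sketch shows what a single v-prediction step at t = 999 recovers. It assumes the common cosine schedule of Nichol & Dhariwal (offset s = 0.008, T = 1000 training steps); FlashSR's exact schedule constants are not given in this card. The names `alpha_bar` and `x0_from_v` are illustrative, not part of any FlashSR API.

```python
import numpy as np

T = 1000   # assumed number of training timesteps (t = 999 is the last one)
S = 0.008  # assumed cosine-schedule offset (Nichol & Dhariwal default)


def _f(t: float) -> float:
    return float(np.cos((t / T + S) / (1 + S) * np.pi / 2) ** 2)


def alpha_bar(t: int) -> float:
    """Cumulative signal level alpha_bar_t of a cosine noise schedule."""
    return _f(t) / _f(0)


def x0_from_v(x_t: np.ndarray, v: np.ndarray, t: int) -> np.ndarray:
    """Recover the clean latent from a v-prediction at timestep t.

    With x_t = a*x0 + s*eps and v = a*eps - s*x0, where a = sqrt(alpha_bar)
    and s = sqrt(1 - alpha_bar), it follows that x0 = a*x_t - s*v.
    """
    a, s = np.sqrt(alpha_bar(t)), np.sqrt(1.0 - alpha_bar(t))
    return a * x_t - s * v


# Round trip on random data: noise a latent, form the v target, recover x0.
rng = np.random.default_rng(0)
x0, eps = rng.standard_normal((2, 8, 16))
t = 999
a, s = np.sqrt(alpha_bar(t)), np.sqrt(1.0 - alpha_bar(t))
x_t = a * x0 + s * eps   # forward-noised latent
v = a * eps - s * x0     # the quantity a v-prediction network is trained to output
assert np.allclose(x0_from_v(x_t, v, t), x0)
```

Under this assumed schedule, alpha_bar(999) is on the order of 1e-6. The input at t = 999 is therefore almost pure noise, and the distilled student's one step effectively maps noise plus the low-resolution conditioning straight to a clean latent.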

Files

| File | Size | Description |
|---|---|---|
| model.safetensors | 346 MB | INT4-quantized VAE + UNet + SR vocoder weights |
| config.json | ~70 KB | Sub-model configs, quantization metadata, and the original-shape table used for dequantize-on-load |
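As a rough sanity check on the 346 MB bundle size, the arithmetic below estimates the storage cost of affine group quantization. It assumes that every one of the 638 M parameters is stored as INT4 and that each 64-weight group carries one bf16 scale and one bf16 bias. Both are assumptions; the card does not say which tensors, if any, stay unquantized. `int4_bundle_bytes` is a hypothetical helper.

```python
def int4_bundle_bytes(n_params: int, bits: int = 4, group_size: int = 64,
                      scale_bytes: int = 2) -> float:
    """Estimate bytes for affine group-quantized weights.

    Each weight costs bits/8 bytes; each group of `group_size` weights also
    stores one scale and one bias (assumed bf16, i.e. 2 bytes each).
    """
    per_weight = bits / 8 + 2 * scale_bytes / group_size
    return n_params * per_weight


total = int4_bundle_bytes(638_000_000)
print(f"{total / 1e6:.1f} MB = {total / 2**20:.1f} MiB")  # → 358.9 MB = 342.2 MiB
```

The estimate is within a few percent of the reported size. The exact figure depends on whether "MB" means 10^6 or 2^20 bytes and on which small tensors (for example norms and biases) are kept at higher precision.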

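The three sub-models share one safetensors file and are distinguished by the key prefixes vae.*, ldm.* and voc.*. config.quantized_shapes records each tensor's original (pre-flatten) shape. Conv weights flattened to 2-D for quantization can therefore be reshaped back to their original shape after mx.dequantize at load time.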
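The flatten, quantize, dequantize and reshape round trip described above can be illustrated with a NumPy sketch of group-wise affine INT4 quantization (w ≈ q * scale + bias per group of 64). This is an approximation of the idea, not MLX's kernel. The real mx.quantize packs codes into uint32 words and may choose scales and biases differently, and reading "mlx_affine_flat" as "flatten, then affine-quantize" is an assumption. The weight shape and helper names are hypothetical.

```python
import numpy as np

GROUP, BITS = 64, 4  # group size and bit width from the model table


def quantize_affine(w2d: np.ndarray):
    """Group-wise affine quantization along the last axis (sketch).

    Returns integer codes q in [0, 2**BITS - 1] plus a per-group scale and
    bias, such that w ~= q * scale + bias.
    """
    rows, cols = w2d.shape
    assert cols % GROUP == 0, "last dim must be a multiple of the group size"
    g = w2d.reshape(rows, cols // GROUP, GROUP)
    lo, hi = g.min(axis=-1, keepdims=True), g.max(axis=-1, keepdims=True)
    scale = np.maximum(hi - lo, 1e-8) / (2**BITS - 1)
    q = np.clip(np.round((g - lo) / scale), 0, 2**BITS - 1).astype(np.uint8)
    return q, scale, lo


def dequantize_affine(q, scale, bias, original_shape):
    """Rebuild float weights, then restore the pre-flatten (conv) shape."""
    w2d = (q.astype(np.float32) * scale + bias).reshape(q.shape[0], -1)
    return w2d.reshape(original_shape)


# Hypothetical conv1d weight in MLX's channels-last layout (out, kernel, in).
# Flattening to (32, 192) lets the last axis split into 64-wide groups.
rng = np.random.default_rng(0)
w = rng.standard_normal((32, 3, 64)).astype(np.float32)
shape = w.shape  # the kind of entry config.quantized_shapes would record
q, scale, bias = quantize_affine(w.reshape(shape[0], -1))
w_hat = dequantize_affine(q, scale, bias, shape)
print(w_hat.shape, float(np.abs(w_hat - w).max() / scale.max()))
```

Round-to-nearest bounds each element's error by half its group's scale, so the printed ratio stays at or below about 0.5.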

Performance (Apple Silicon, M-series, 5.12 s @ 48 kHz)

| Metric | Value |
|---|---|
| Real-time factor (wall time / audio duration) | 1.10 |
| Load time | 0.17 s (dequantization materialises bf16 weights once) |
| SNR vs FP16 reference | +29.4 dB |
| Cosine similarity vs FP16 | 0.9994 |
| Peak amplitude preservation | 1.000 |

INT4 is the recommended deployment variant. Its fidelity to the FP16 reference is well above the perceptual threshold for music, and it is the smallest viable on-device bundle.
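The card does not say exactly how these fidelity metrics were computed. The following minimal sketch uses the standard definitions and assumes time-aligned float waveforms of equal length for the INT4 output and the FP16 reference. The function names are illustrative.

```python
import numpy as np


def snr_db(ref: np.ndarray, test: np.ndarray) -> float:
    """Signal-to-noise ratio of `test` against `ref`, in dB."""
    noise = ref - test
    return float(10 * np.log10(np.sum(ref**2) / np.sum(noise**2)))


def cosine_similarity(ref: np.ndarray, test: np.ndarray) -> float:
    """Cosine of the angle between two flattened waveforms."""
    return float(np.dot(ref, test) / (np.linalg.norm(ref) * np.linalg.norm(test)))


def peak_ratio(ref: np.ndarray, test: np.ndarray) -> float:
    """Ratio of peak absolute amplitudes (1.0 means the peak is preserved)."""
    return float(np.max(np.abs(test)) / np.max(np.abs(ref)))


# For equal-norm signals, ||ref - test||^2 = 2 ||ref||^2 (1 - cos), so
# SNR = -10 log10(2 (1 - cos)).
print(round(float(-10 * np.log10(2 * (1 - 0.9994))), 1))  # → 29.2
```

Under that equal-norm relation, the reported +29.4 dB implies a cosine similarity of about 0.99943, which rounds to the 0.9994 in the table. The two reported figures are therefore mutually consistent.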

Usage

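```python
from huggingface_hub import snapshot_download
import mlx.core as mx
import numpy as np
import scipy.io.wavfile as wf
from scipy.signal import resample_poly

bundle = snapshot_download("aufklarer/FlashSR-MLX-4bit")
# See https://github.com/soniqo/speech-swift for production usage.

# Toy Python demo (requires the matching MLX FlashSR runtime):
sr, audio = wf.read("lr.wav")
if np.issubdtype(audio.dtype, np.signedinteger):
    # Scale integer PCM to [-1, 1): int16 divides by 32768, int32 by 2**31
    audio = audio / float(-np.iinfo(audio.dtype).min)
audio = audio.astype(np.float32)
if audio.ndim == 2:
    audio = audio.mean(axis=1)  # the model is mono: downmix multichannel input
audio_48 = resample_poly(audio, 48000, sr).astype(np.float32)

from flashsr import FlashSR  # from this repo's export/
model = FlashSR(bundle)
hr = model(mx.array(audio_48), seed=42)
mx.eval(hr)
# Cast to float32 before converting to NumPy, in case the output is bf16
hr_np = np.array(hr.astype(mx.float32))
wf.write("hr.wav", 48000, (np.clip(hr_np, -1, 1) * 32767).astype(np.int16))
```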
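The model processes one 5.12 s frame (245 760 samples at 48 kHz) per forward pass. This card does not say whether the runtime splits longer inputs itself. If it does not, the following minimal sketch shows one way to hard-split and re-join audio around per-frame calls. The helpers are hypothetical and do not belong to the FlashSR runtime.

```python
import numpy as np

FRAME = 245_760  # 5.12 s at 48 kHz, the per-forward frame length from the card


def split_into_frames(x: np.ndarray, frame: int = FRAME) -> np.ndarray:
    """Zero-pad a 1-D signal to a whole number of frames; return (n, frame)."""
    n = max(1, -(-len(x) // frame))  # ceiling division, at least one frame
    padded = np.zeros(n * frame, dtype=x.dtype)
    padded[: len(x)] = x
    return padded.reshape(n, frame)


def join_frames(frames: np.ndarray, length: int) -> np.ndarray:
    """Concatenate processed frames and trim the zero padding back off."""
    return frames.reshape(-1)[:length]


audio_48 = np.zeros(300_000, dtype=np.float32)  # placeholder: 6.25 s of silence
frames = split_into_frames(audio_48)            # shape (2, 245760)
# ...run each row of `frames` through the model here...
restored = join_frames(frames, len(audio_48))
```

Hard frame boundaries can leave audible seams. Overlapping frames with a crossfade would be the usual refinement, at the cost of extra forward passes.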

Source

License

CC BY-NC 4.0, inherited from the upstream FlashSR weights. Non-commercial use only.
