---
base_model: black-forest-labs/FLUX.1-dev
library_name: diffusers
base_model_relation: quantized
tags:
- quantization
---
# Visual comparison of Flux-dev model outputs using BF16 and torchao float8_weight_only quantization

<table>
<tr>
<td style="text-align: center;">
BF16<br>
<medium-zoom background="rgba(0,0,0,.7)"><img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/quantization-backends-diffusers/combined_flux-dev_bf16_combined.png" alt="Flux-dev output with BF16: Baroque, Futurist, Noir styles"></medium-zoom>
</td>
<td style="text-align: center;">
torchao fp8_weight_only<br>
<medium-zoom background="rgba(0,0,0,.7)"><img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/quantization-backends-diffusers/combined_flux-dev_torchao_fp8_combined.png" alt="torchao fp8_weight_only output"></medium-zoom>
</td>
</tr>
</table>

# Usage with Diffusers

To use this quantized FLUX.1 [dev] checkpoint, you need to install the 🧨 diffusers and torchao libraries:

```shell
pip install -U diffusers
pip install -U torchao
```

After installing the required libraries, you can run the following script:

```python
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "diffusers/FLUX.1-dev-torchao-fp8",
    torch_dtype=torch.bfloat16,
    use_safetensors=False,
    device_map="balanced"
)

prompt = "Baroque style, a lavish palace interior with ornate gilded ceilings, intricate tapestries, and dramatic lighting over a grand staircase."
pipe_kwargs = {
    "prompt": prompt,
    "height": 1024,
    "width": 1024,
    "guidance_scale": 3.5,
    "num_inference_steps": 50,
    "max_sequence_length": 512,
}

image = pipe(
    **pipe_kwargs, generator=torch.manual_seed(0),
).images[0]
image.save("flux.png")
```

# How to generate this quantized checkpoint?

This checkpoint was created with the following script, using the "black-forest-labs/FLUX.1-dev" checkpoint:

```python
import torch
from diffusers import FluxPipeline
from diffusers.quantizers import PipelineQuantizationConfig
from diffusers import TorchAoConfig as DiffusersTorchAoConfig
from transformers import TorchAoConfig as TransformersTorchAoConfig

from torchao.quantization import Float8WeightOnlyConfig

pipeline_quant_config = PipelineQuantizationConfig(
    quant_mapping={
        "transformer": DiffusersTorchAoConfig("float8_weight_only"),
        "text_encoder_2": TransformersTorchAoConfig(Float8WeightOnlyConfig()),
    }
)

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    quantization_config=pipeline_quant_config,
    torch_dtype=torch.bfloat16,
    device_map="balanced"
)

# safe_serialization is set to `False` because torchao-quantized models can't be saved in the safetensors format
pipe.save_pretrained("FLUX.1-dev-torchao-fp8", safe_serialization=False)
```
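
For intuition, weight-only quantization stores each weight tensor in a low-precision format and upcasts it back to the compute dtype (here `bfloat16`) only when the layer runs, so memory shrinks while activations stay in full precision. The sketch below illustrates that round-trip with a generic per-tensor scale and signed 8-bit codes; it is a simplified illustration of the idea, not torchao's actual float8 (e4m3) representation:

```python
# Conceptual sketch of weight-only quantization. This generic per-tensor
# scale + 8-bit code scheme is for illustration only -- it is NOT torchao's
# float8 (e4m3) format, but it shows the same idea: weights are stored in
# low precision and dequantized back to the compute dtype at runtime.

def quantize(weights, num_levels=256):
    """Map float weights to signed 8-bit integer codes plus one float scale."""
    max_abs = max(abs(w) for w in weights)
    scale = max_abs / (num_levels // 2 - 1)  # 127 positive steps
    codes = [round(w / scale) for w in weights]
    return codes, scale

def dequantize(codes, scale):
    """Upcast: recover approximate float weights from the codes and scale."""
    return [c * scale for c in codes]

weights = [0.25, -1.5, 0.01, 0.75]
codes, scale = quantize(weights)
recovered = dequantize(codes, scale)
# Each recovered weight lies within half a quantization step of the original.
```

Each stored code fits in one byte instead of the two bytes a `bfloat16` weight needs, which is where the roughly 2x reduction in transformer weight memory comes from.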