comfyui

Please share the conversion script

#1
by Vigilence - opened

I would like to convert some uncensored versions to use with ComfyUI's default workflow nodes.

Comfy Org org

I think this is what I used initially:

"""
Convert Google Gemma 4 HuggingFace model to ComfyUI format.

Usage:
    python convert_gemma4.py --input /path/to/model.safetensors --output models/text_encoders/gemma4_e4b.safetensors

The script handles:
- Prefix conversion (model.language_model.* -> model.*)
- Vision tower weight remapping (model.vision_tower.* -> vision_model.*)
- Audio tower weight remapping (model.audio_tower.* -> audio_model.*)
- Projector remapping (model.embed_vision.* -> multi_modal_projector.*)
- Keeping quantization-wrapper keys (e.g. .linear.weight and the clipping ranges) unchanged
- Embedding tokenizer.json as a uint8 tensor under the tokenizer_json key
"""

import argparse
import os
import torch
from safetensors.torch import load_file, save_file


def _list_shards(model_dir):
    """Return the sorted shard file names of a sharded checkpoint in *model_dir*."""
    import json

    index_path = os.path.join(model_dir, "model.safetensors.index.json")
    if os.path.isfile(index_path):
        # The index is authoritative: stray .safetensors files in the same
        # directory (e.g. an earlier conversion output) are ignored.
        with open(index_path, encoding="utf-8") as f:
            weight_map = json.load(f)["weight_map"]
        return sorted(set(weight_map.values()))
    return sorted(name for name in os.listdir(model_dir) if name.endswith(".safetensors"))


def load_model(input_path):
    """Load a safetensors checkpoint into a single state dict.

    Args:
        input_path: Either a single ``.safetensors`` file, or a directory
            holding a sharded checkpoint. For a directory, the shard list comes
            from ``model.safetensors.index.json`` when present. Otherwise every
            ``*.safetensors`` file in the directory is loaded.

    Returns:
        ``(state_dict, model_dir)``. ``model_dir`` is the directory in which
        ``tokenizer.json`` is looked up afterwards.

    Raises:
        ValueError: if the path is neither an existing file nor a directory,
            if the directory has no shards, or if two shards define the same
            tensor name.
    """
    if os.path.isfile(input_path):
        print(f"Loading from local file: {input_path}")
        return load_file(input_path), os.path.dirname(input_path)

    if not os.path.isdir(input_path):
        raise ValueError(
            f"Input path must be a local .safetensors file or a directory of shards: {input_path}"
        )

    shard_names = _list_shards(input_path)
    if not shard_names:
        raise ValueError(f"No .safetensors files found in directory: {input_path}")

    state_dict = {}
    for name in shard_names:
        shard_path = os.path.join(input_path, name)
        print(f"Loading shard: {shard_path}")
        shard = load_file(shard_path)
        # Merging with dict.update would silently keep only the last copy.
        duplicates = state_dict.keys() & shard.keys()
        if duplicates:
            raise ValueError(
                f"Tensor defined in more than one shard, e.g. {sorted(duplicates)[0]!r} (in {name})"
            )
        state_dict.update(shard)
    return state_dict, input_path

def load_tokenizer(model_dir):
    """Read ``tokenizer.json`` from *model_dir* as a flat uint8 tensor.

    The raw file bytes are stored unchanged so they can be embedded in the
    output safetensors file next to the weights.

    Args:
        model_dir: Directory expected to contain ``tokenizer.json``. An empty
            string resolves relative to the current working directory.

    Returns:
        A 1-D ``torch.uint8`` tensor holding the file's bytes. Returns ``None``
        after printing a warning when the file does not exist.
    """
    tokenizer_path = os.path.join(model_dir, "tokenizer.json")
    if not os.path.exists(tokenizer_path):
        print("Warning: Could not find tokenizer.json")
        return None

    print(f"Loading tokenizer.json from: {tokenizer_path}")
    with open(tokenizer_path, "rb") as f:
        tokenizer_bytes = f.read()

    if not tokenizer_bytes:
        # torch.frombuffer rejects zero-length buffers.
        return torch.empty(0, dtype=torch.uint8)
    # Wrap the bytes directly instead of building a Python list with one int
    # per byte. bytearray provides a writable buffer, which avoids
    # frombuffer's non-writable-buffer warning. The tensor keeps that buffer alive.
    return torch.frombuffer(bytearray(tokenizer_bytes), dtype=torch.uint8)


def strip_quantization_wrapper(key):
    """Return *key* unchanged (identity mapping).

    Despite the name, no suffix is removed. ClippedLinear wrappers in the
    vision/audio towers store their matrix under ``*.linear.weight``, next to
    activation clipping bounds (``input_max``/``input_min``,
    ``output_max``/``output_min``). The target model structure keeps that
    layout, so every key passes through as-is. Never returns ``None``.
    """
    unchanged_key = key
    return unchanged_key


# HuggingFace key prefix -> ComfyUI key prefix. None of these prefixes
# overlap, so at most one entry matches any given key.
_HF_TO_COMFY_PREFIXES = (
    ("model.language_model.", "model."),                # text model
    ("model.embed_vision.", "multi_modal_projector."),  # vision projector
    ("model.embed_audio.", "audio_projector."),         # audio projector
    ("model.vision_tower.", "vision_model."),           # vision tower
    ("model.audio_tower.", "audio_model."),             # audio tower
)


def convert_weights(state_dict):
    """Rename HuggingFace Gemma 4 weight keys to the ComfyUI layout.

    Each key whose prefix appears in ``_HF_TO_COMFY_PREFIXES`` has that prefix
    replaced. Other keys are kept as-is. Every key is then passed through
    ``strip_quantization_wrapper``; a ``None`` result drops the key and counts
    it as skipped.

    Args:
        state_dict: Mapping of HuggingFace key -> tensor.

    Returns:
        A new dict of ComfyUI key -> the same tensor objects.

    Raises:
        ValueError: if two source keys map to the same ComfyUI key. Silently
            overwriting would drop one of the weights.
    """
    new_dict = {}
    skipped = 0

    for key, value in state_dict.items():
        new_key = key
        for hf_prefix, comfy_prefix in _HF_TO_COMFY_PREFIXES:
            if new_key.startswith(hf_prefix):
                new_key = comfy_prefix + new_key[len(hf_prefix):]
                break

        # Hook for dropping/renaming wrapper keys. A None result means "skip".
        new_key = strip_quantization_wrapper(new_key)
        if new_key is None:
            skipped += 1
            continue

        if new_key in new_dict:
            raise ValueError(f"Key collision: {key!r} maps to {new_key!r}, which is already taken")
        new_dict[new_key] = value

    print(f"Converted {len(new_dict)} weights, skipped {skipped} quantization metadata keys")
    return new_dict


def print_key_summary(state_dict):
    """Print a per-prefix key count for *state_dict*, sorted by prefix.

    Keys are grouped by their first dot-separated component. Keys under
    ``model.`` are grouped by their first two components instead (e.g.
    ``model.layers``), so the text model gets a finer breakdown.
    """
    counts = {}
    for name in state_dict:
        head, dot, tail = name.partition('.')
        group = 'model.' + tail.partition('.')[0] if (head == 'model' and dot) else head
        if group in counts:
            counts[group] += 1
        else:
            counts[group] = 1

    print("\nKey prefix summary:")
    for group in sorted(counts):
        print(f"  {group}: {counts[group]} keys")


def main():
    """Command-line entry point: convert one Gemma 4 checkpoint for ComfyUI.

    Steps:
      1. Load ``--input``.
      2. With ``--text-only``, drop every key containing a vision/audio tower or
         projector name.
      3. Rename keys with ``convert_weights``.
      4. Embed ``tokenizer.json`` from the input's directory (when found) under
         the ``tokenizer_json`` key.
      5. Write the result to ``--output``.
    """
    parser = argparse.ArgumentParser(description="Convert Gemma 4 (E4B/E2B) to ComfyUI format")
    parser.add_argument("--input", required=True, help="Path to model.safetensors")
    parser.add_argument("--output", required=True, help="Output path for ComfyUI safetensors file")
    parser.add_argument("--text-only", action="store_true", help="Only include text model weights (no vision/audio)")
    args = parser.parse_args()

    # Refuse to write the converted file over the source checkpoint, which
    # would replace the original weights.
    if os.path.exists(args.input) and os.path.exists(args.output) and os.path.samefile(args.input, args.output):
        parser.error("--output must not point to the same file as --input")

    # Load model
    state_dict, model_dir = load_model(args.input)
    print(f"Loaded {len(state_dict)} weights")

    # Filter if text-only. Substring match: any key containing one of these
    # names anywhere is dropped.
    if args.text_only:
        non_text_markers = ("vision_tower", "audio_tower", "embed_vision", "embed_audio")
        state_dict = {k: v for k, v in state_dict.items()
                      if not any(marker in k for marker in non_text_markers)}
        print(f"Filtered to {len(state_dict)} text-only weights")

    # Convert weights
    converted = convert_weights(state_dict)
    print_key_summary(converted)

    # Load and embed tokenizer (raw tokenizer.json bytes)
    tokenizer_tensor = load_tokenizer(model_dir)
    if tokenizer_tensor is not None:
        converted["tokenizer_json"] = tokenizer_tensor
        print(f"Added tokenizer_json ({len(tokenizer_tensor)} bytes)")

    # Save
    output_dir = os.path.dirname(args.output)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    save_file(converted, args.output)
    file_size = os.path.getsize(args.output) / (1024 * 1024 * 1024)
    print(f"\nSaved to {args.output} ({file_size:.2f} GB)")


if __name__ == "__main__":
    main()

@kijai Thank you so much! I will test it later when I get to my desktop.

I also wanted to ask: does this script give the option to make an fp8 scaled version?

Comfy Org org

For the fp8 scaled version I just had this script with hardcoded paths. It doesn't convert anything, it only quantizes, so run the first script first. It's probably far from optimal, just a basic one:

from safetensors.torch import save_file
from safetensors import safe_open
import torch
import json
import gc

out_dtype = torch.float8_e4m3fn
inf = torch.finfo(out_dtype)
max_value = 416

quant_conf = {
    "format": "float8_e4m3fn",
    "full_precision_matrix_mult": False,
}

source = r"S:\AI\comfy_models\text_encoders\Gemma4\gemma4_e4b.safetensors"
output = r"S:\AI\comfy_models\text_encoders\Gemma4\gemma4_e4b_fp8_scaled.safetensors"


def quantize_weight(key, w):
    w = w.float()
    scale = torch.max(torch.abs(w)) / max_value
    print(f"  {key}: scale={scale.item():.6f}")
    w_q = (w / scale).clamp(min=inf.min, max=inf.max).to(dtype=out_dtype)
    quant_tensor = torch.tensor(list(json.dumps(quant_conf).encode("utf-8")), dtype=torch.uint8)
    return [
        (key, w_q),
        (key.replace(".weight", ".weight_scale"), scale),
        (key.replace(".weight", ".comfy_quant"), quant_tensor),
    ]


# Stream tensors from the source checkpoint. Large 2-D "model.*" weights are
# quantized via quantize_weight; everything else is copied through unchanged.
sd_new = {}

# The context manager closes the file handle once every tensor has been read.
with safe_open(source, framework="pt") as f:
    for k in f.keys():
        v = f.get_tensor(k)

        # Candidates: 2-D ".weight" tensors under "model.". Norm weights and
        # matrices whose largest dimension is below 4096 stay in full precision.
        if k.startswith("model.") and k.endswith(".weight") and v.dim() == 2:
            if "norm" in k or max(v.shape) < 4096:
                sd_new[k] = v
            else:
                sd_new.update(quantize_weight(k, v))
        else:
            sd_new[k] = v

        del v

print(f"\nProcessed into {len(sd_new)} tensors")
gc.collect()

# Summary (every value in sd_new is a tensor)
total = sum(t.numel() * t.element_size() for t in sd_new.values())
print(f"Total size: {total / 1024**3:.2f} GB")

save_file(sd_new, output)
print(f"Saved to {output}")

Thank you for the fp8 scaled script. Is it only useful for Gemma 4 models?

Sign up or log in to comment