Please share the conversion script
#1
by Vigilence - opened
I would like to convert some uncensored versions to use with ComfyUI's default workflow nodes.
I think this is what I used initially:
"""
Convert Google Gemma 4 HuggingFace model to ComfyUI format.
Usage:
    python convert_gemma4.py --input /path/to/model.safetensors --output models/text_encoders/gemma4_e4b.safetensors
    (add --text-only to drop the vision/audio towers and projectors)
The script handles:
- Prefix conversion (model.language_model.* -> model.*)
- Vision tower weight remapping (model.vision_tower.* -> vision_model.*)
- Audio tower weight remapping (model.audio_tower.* -> audio_model.*)
- Vision projector remapping (model.embed_vision.* -> multi_modal_projector.*)
- Audio projector remapping (model.embed_audio.* -> audio_projector.*)
- Quantization-wrapper keys (e.g. proj.linear.weight, proj.input_max) are kept
  unchanged; strip_quantization_wrapper() is currently a pass-through hook
- Embedding the raw tokenizer.json file (if found next to the input) as a
  uint8 tensor under the "tokenizer_json" key
"""
import argparse
import os
import torch
from safetensors.torch import load_file, save_file
def _list_shards(model_dir):
    """Return the sorted .safetensors file names that make up the checkpoint in model_dir."""
    import json

    index_path = os.path.join(model_dir, "model.safetensors.index.json")
    if os.path.isfile(index_path):
        # Standard HF sharded layout: weight_map maps tensor name -> shard file name.
        # Using it avoids picking up unrelated .safetensors files in the directory.
        with open(index_path, encoding="utf-8") as fh:
            weight_map = json.load(fh).get("weight_map", {})
        return sorted(set(weight_map.values()))
    return sorted(name for name in os.listdir(model_dir) if name.endswith(".safetensors"))


def load_model(input_path):
    """Load a safetensors checkpoint into a single state dict.

    Args:
        input_path: Path to one ``.safetensors`` file, or to a directory holding
            a (possibly sharded) checkpoint. For a directory, the shard files
            listed in ``model.safetensors.index.json`` are loaded when that index
            exists; otherwise every ``*.safetensors`` file in the directory is.

    Returns:
        ``(state_dict, model_dir)`` where ``model_dir`` is the directory holding
        the weights (used afterwards to look for ``tokenizer.json``).

    Raises:
        ValueError: if ``input_path`` is neither a file nor a directory that
            contains ``.safetensors`` files, or if two shards define the same
            tensor name.
    """
    if os.path.isfile(input_path):
        print(f"Loading from local file: {input_path}")
        return load_file(input_path), os.path.dirname(input_path)
    if os.path.isdir(input_path):
        shard_names = _list_shards(input_path)
        if shard_names:
            state_dict = {}
            for name in shard_names:
                shard_path = os.path.join(input_path, name)
                print(f"Loading shard: {shard_path}")
                shard = load_file(shard_path)
                # Refuse to silently let one shard overwrite another's tensors.
                overlap = state_dict.keys() & shard.keys()
                if overlap:
                    raise ValueError(
                        f"Duplicate tensor name {sorted(overlap)[0]!r} in {shard_path}"
                    )
                state_dict.update(shard)
            return state_dict, input_path
    raise ValueError(
        "Input path must be a local .safetensors file or a directory "
        f"containing .safetensors files: {input_path}"
    )
def load_tokenizer(model_dir):
    """Load ``tokenizer.json`` from ``model_dir`` as a flat uint8 tensor.

    The file's raw bytes are kept unchanged, so they can be stored inside the
    safetensors output (safetensors files can only hold tensors).

    Args:
        model_dir: Directory expected to contain ``tokenizer.json``.

    Returns:
        A 1-D ``torch.uint8`` tensor holding the file's bytes. Returns ``None``
        after printing a warning if the file does not exist.
    """
    tokenizer_path = os.path.join(model_dir, "tokenizer.json")
    if not os.path.exists(tokenizer_path):
        print("Warning: Could not find tokenizer.json")
        return None
    print(f"Loading tokenizer.json from: {tokenizer_path}")
    with open(tokenizer_path, "rb") as f:
        # bytearray gives torch.frombuffer a writable buffer (avoids its warning).
        tokenizer_bytes = bytearray(f.read())
    if not tokenizer_bytes:
        # torch.frombuffer rejects zero-length buffers.
        return torch.empty(0, dtype=torch.uint8)
    # Wrap the bytes directly instead of building a Python list with one int
    # object per byte (tokenizer.json can be tens of megabytes). torch keeps
    # the bytearray alive for as long as the returned tensor exists.
    return torch.frombuffer(tokenizer_bytes, dtype=torch.uint8)
def strip_quantization_wrapper(key):
    """Return ``key`` unchanged (placeholder hook for wrapper stripping).

    Despite the name, no renaming happens. The ClippedLinear wrapper used by
    the vision/audio weights contributes keys such as ``proj.linear.weight``
    (the actual weight matrix) and ``proj.input_max`` / ``proj.input_min`` /
    ``proj.output_max`` / ``proj.output_min`` (activation clipping ranges).
    All of them are kept verbatim to match the target model structure.
    convert_weights() treats a ``None`` return as "skip this key"; this
    implementation never returns ``None``.
    """
    passthrough = key
    return passthrough
def convert_weights(state_dict):
    """Rename HuggingFace Gemma 4 state-dict keys to the names ComfyUI expects.

    Prefix mapping (first match wins; keys matching none are kept as-is):
        model.language_model.*  -> model.*
        model.embed_vision.*    -> multi_modal_projector.*
        model.embed_audio.*     -> audio_projector.*
        model.vision_tower.*    -> vision_model.*
        model.audio_tower.*     -> audio_model.*

    Each renamed key is then passed through strip_quantization_wrapper(). A
    ``None`` result drops that tensor and counts it as skipped.

    Args:
        state_dict: Mapping of tensor name -> tensor.

    Returns:
        A new dict with renamed keys. Tensor objects are reused, not copied.

    Raises:
        ValueError: if two source keys map to the same output key. Without this
            check one tensor would be silently overwritten.
    """
    prefix_map = (
        ("model.language_model.", "model."),          # text model
        ("model.embed_vision.", "multi_modal_projector."),  # vision projector
        ("model.embed_audio.", "audio_projector."),   # audio projector
        ("model.vision_tower.", "vision_model."),     # vision tower
        ("model.audio_tower.", "audio_model."),       # audio tower
    )
    new_dict = {}
    source_of = {}  # output key -> originating input key, for collision errors
    skipped = 0
    for key, value in state_dict.items():
        new_key = key
        for old_prefix, new_prefix in prefix_map:
            if key.startswith(old_prefix):
                new_key = new_prefix + key[len(old_prefix):]
                break
        # Strip quantization wrappers
        new_key = strip_quantization_wrapper(new_key)
        if new_key is None:
            skipped += 1
            continue
        if new_key in new_dict:
            raise ValueError(
                f"Key collision: {source_of[new_key]!r} and {key!r} both map to {new_key!r}"
            )
        source_of[new_key] = key
        new_dict[new_key] = value
    print(f"Converted {len(new_dict)} weights, skipped {skipped} quantization metadata keys")
    return new_dict
def print_key_summary(state_dict):
    """Print how many keys fall under each prefix, sorted by prefix.

    Keys starting with ``model.`` are grouped by their first two dot-separated
    components (e.g. ``model.layers``). All other keys are grouped by their
    first component only (e.g. ``vision_model``).
    """
    tally = {}
    for name in state_dict:
        if name.startswith('model.'):
            group = 'model.' + name[len('model.'):].partition('.')[0]
        else:
            group = name.partition('.')[0]
        tally[group] = tally.get(group, 0) + 1

    print("\nKey prefix summary:")
    for group in sorted(tally):
        print(f" {group}: {tally[group]} keys")
def main():
    """Command-line entry point: load, optionally filter, rename, and save.

    Reads the checkpoint given by --input. With --text-only, drops every key
    mentioning a vision/audio tower or projector. Renames keys via
    convert_weights() and embeds tokenizer.json from the input's directory
    when present. Writes the result to --output, creating its parent directory
    if needed.
    """
    cli = argparse.ArgumentParser(description="Convert Gemma 4 (E4B/E2B) to ComfyUI format")
    cli.add_argument("--input", required=True, help="Path to model.safetensors")
    cli.add_argument("--output", required=True, help="Output path for ComfyUI safetensors file")
    cli.add_argument("--text-only", action="store_true", help="Only include text model weights (no vision/audio)")
    opts = cli.parse_args()

    weights, source_dir = load_model(opts.input)
    print(f"Loaded {len(weights)} weights")

    if opts.text_only:
        # Substring match anywhere in the key: any mention of a multimodal
        # component drops the tensor.
        multimodal_markers = ('vision_tower', 'audio_tower', 'embed_vision', 'embed_audio')
        weights = {
            name: tensor
            for name, tensor in weights.items()
            if all(marker not in name for marker in multimodal_markers)
        }
        print(f"Filtered to {len(weights)} text-only weights")

    converted = convert_weights(weights)
    print_key_summary(converted)

    tokenizer_blob = load_tokenizer(source_dir)
    if tokenizer_blob is not None:
        converted["tokenizer_json"] = tokenizer_blob
        print(f"Added tokenizer_json ({len(tokenizer_blob)} bytes)")

    destination_dir = os.path.dirname(opts.output)
    if destination_dir:
        os.makedirs(destination_dir, exist_ok=True)
    save_file(converted, opts.output)
    size_gib = os.path.getsize(opts.output) / 1024 ** 3
    print(f"\nSaved to {opts.output} ({size_gib:.2f} GB)")


if __name__ == "__main__":
    main()
For the fp8 scaled version I just used the script below with hardcoded paths. It doesn't convert key names, it only quantizes, so run the first script first. It's probably far from optimal, just a basic one:
from safetensors.torch import save_file
from safetensors import safe_open
import torch
import json
import gc
out_dtype = torch.float8_e4m3fn
inf = torch.finfo(out_dtype)
max_value = 416
quant_conf = {
"format": "float8_e4m3fn",
"full_precision_matrix_mult": False,
}
source = r"S:\AI\comfy_models\text_encoders\Gemma4\gemma4_e4b.safetensors"
output = r"S:\AI\comfy_models\text_encoders\Gemma4\gemma4_e4b_fp8_scaled.safetensors"
def quantize_weight(key, w):
w = w.float()
scale = torch.max(torch.abs(w)) / max_value
print(f" {key}: scale={scale.item():.6f}")
w_q = (w / scale).clamp(min=inf.min, max=inf.max).to(dtype=out_dtype)
quant_tensor = torch.tensor(list(json.dumps(quant_conf).encode("utf-8")), dtype=torch.uint8)
return [
(key, w_q),
(key.replace(".weight", ".weight_scale"), scale),
(key.replace(".weight", ".comfy_quant"), quant_tensor),
]
# Quantize every large 2-D text-model weight; copy all other tensors through.
sd_new = {}
# The context manager closes the safetensors handle once all tensors are read,
# rather than leaving it open for the rest of the run. Tensors returned by
# get_tensor remain usable afterwards (this is how load_file itself works).
with safe_open(source, framework="pt") as f:
    for k in f.keys():
        v = f.get_tensor(k)
        # Candidates are 2-D ".weight" matrices under "model." (the text model
        # after conversion). Norm weights and matrices whose larger side is
        # < 4096 keep their original precision.
        quantize = (
            k.startswith("model.")
            and k.endswith(".weight")
            and v.dim() == 2
            and "norm" not in k
            and max(v.shape) >= 4096
        )
        if quantize:
            sd_new.update(quantize_weight(k, v))
        else:
            sd_new[k] = v
        del v
print(f"\nProcessed into {len(sd_new)} tensors")
gc.collect()
# Summary: in-memory byte size of everything about to be written.
total = sum(t.numel() * t.element_size() for t in sd_new.values() if isinstance(t, torch.Tensor))
print(f"Total size: {total / 1024**3:.2f} GB")
save_file(sd_new, output)
print(f"Saved to {output}")
Thank you for the fp8 scaled script. Is it only useful for Gemma 4 models?