Model Spec Midtraining - General Spec
Collection
10 items • Updated
How to use chloeli/qwen-2.5-32b-general-spec-msm with PEFT:
from peft import PeftModel
from transformers import AutoModelForCausalLM

# Load the base checkpoint first, then wrap it with the adapter weights.
base_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-32B-Instruct")
model = PeftModel.from_pretrained(base_model, "chloeli/qwen-2.5-32b-general-spec-msm")

# A LoRA adapter for Qwen/Qwen2.5-32B-Instruct, trained using model spec midtraining (MSM) only.
# Chat inference: base model + LoRA adapter, prompt built with the chat template.
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

base_model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-32B-Instruct",
    torch_dtype="auto",
    device_map="auto",
)
model = PeftModel.from_pretrained(base_model, "chloeli/qwen-2.5-32b-general-spec-msm")
tokenizer = AutoTokenizer.from_pretrained("chloeli/qwen-2.5-32b-general-spec-msm")

messages = [{"role": "user", "content": "What matters most when making a difficult decision?"}]
# Render the conversation to a prompt string that ends with the assistant turn opener.
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = tokenizer(text, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=512)

# For a causal LM, generate() returns the prompt tokens followed by the new
# tokens. Skip the prompt so only the model's reply is printed rather than the
# whole rendered conversation.
prompt_length = inputs["input_ids"].shape[-1]
print(tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True))
# Merge the LoRA adapter into the base weights and save a standalone checkpoint.
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

BASE_MODEL_ID = "Qwen/Qwen2.5-32B-Instruct"
ADAPTER_ID = "chloeli/qwen-2.5-32b-general-spec-msm"
MERGED_DIR = "qwen-2.5-32b-general-spec-msm-merged"

# The base model is placed on CPU for the merge (device_map="cpu").
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_ID,
    torch_dtype="auto",
    device_map="cpu",
)
model = PeftModel.from_pretrained(base_model, ADAPTER_ID)

# Fold the adapter into the base model and write the result to MERGED_DIR.
merged_model = model.merge_and_unload()
merged_model.save_pretrained(MERGED_DIR)

# Save the tokenizer next to the merged weights so the directory is self-contained.
tokenizer = AutoTokenizer.from_pretrained(ADAPTER_ID)
tokenizer.save_pretrained(MERGED_DIR)
# Serve the adapter with vLLM by attaching it to the base model per request.
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

VLLM_BASE_MODEL_ID = "Qwen/Qwen2.5-32B-Instruct"
VLLM_ADAPTER_ID = "chloeli/qwen-2.5-32b-general-spec-msm"

# LoRA support has to be enabled when the engine is constructed.
# NOTE(review): max_lora_rank=128 presumably covers this adapter's rank; confirm
# against the adapter config.
llm = LLM(
    model=VLLM_BASE_MODEL_ID,
    enable_lora=True,
    max_lora_rank=128,
)

# Positional arguments: adapter name, integer adapter id, adapter location.
lora_request = LoRARequest("adapter", 1, VLLM_ADAPTER_ID)
sampling_params = SamplingParams(max_tokens=512)
output = llm.generate("What matters most?", sampling_params, lora_request=lora_request)