🥥 Coconut: Chain of Continuous Thought
LLMs reasoning in latent space instead of generating text tokens.
Implementation of Training Large Language Models to Reason in a Continuous Latent Space (Meta FAIR, Dec 2024).
The Idea
Standard chain-of-thought (CoT) forces a model to serialize its reasoning into discrete tokens, which the Coconut paper argues is a bottleneck. Coconut replaces explicit reasoning steps with continuous hidden-state vectors that are fed directly back as input embeddings:
Standard CoT: [Question] → "Step 1..." → "Step 2..." → "Step 3..." → [Answer]
                           ↑ tokens ↑    ↑ tokens ↑    ↑ tokens ↑
Coconut:      [Question] → [h₁] → [h₂] → [h₃] → [Answer]
                            ↑      ↑      ↑
                            hidden states (768-dim vectors, NOT text)
The model "thinks" in a continuous 768-dimensional space where it can:
- Represent uncertainty (continuous values vs discrete tokens)
- Maintain multiple hypotheses simultaneously (BFS-like breadth)
- Skip the tokenization bottleneck entirely
How It Works
Multi-Stage Curriculum Training
Stage 0: [Q] step₁ step₂ step₃ step₄ #### answer ← Full CoT (standard SFT)
Stage 1: [Q] <bot> h h <eot> step₂ step₃ step₄ #### answer ← 1st step → latent
Stage 2: [Q] <bot> h h h h <eot> step₃ step₄ #### answer ← 2 steps → latent
Stage 3: [Q] <bot> h h h h h h <eot> step₄ #### answer ← 3 steps → latent
Stage 4: [Q] <bot> h h h h h h h h <eot> #### answer ← ALL steps → latent
Each h is a continuous thought — the model's last-layer hidden state fed directly as the next input embedding. The optimizer is completely reset at each stage transition.
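The stage layouts above can be sketched as a small sequence builder. This is a minimal illustration, not the code in `coconut_train.py`: the function name `build_stage_sequence`, the segment-list representation, and the clamping of `stage` to the number of available steps are all assumptions made for the example. Only the token strings and the `c = 2` default come from this README.

```python
from typing import List

# Token strings as listed in the Special Tokens table.
BOT, LATENT, EOT = "<|start-latent|>", "<|latent|>", "<|end-latent|>"


def build_stage_sequence(question: str, steps: List[str], answer: str,
                         stage: int, c: int = 2) -> List[str]:
    """Return the training sequence for one curriculum stage as a list of segments."""
    if stage == 0:
        # Stage 0 is plain CoT: every reasoning step stays as text.
        return [question, *steps, "####", answer]
    # Assumption: a stage never replaces more steps than the example has.
    n_replaced = min(stage, len(steps))
    latents = [LATENT] * (n_replaced * c)  # c placeholders per replaced step
    return [question, BOT, *latents, EOT, *steps[n_replaced:], "####", answer]


if __name__ == "__main__":
    demo_steps = ["step1", "step2", "step3", "step4"]
    for stage in range(5):
        print(stage, " ".join(build_stage_sequence("[Q]", demo_steps, "answer", stage)))
```

During training, each `<|latent|>` placeholder would be overwritten with a hidden state rather than looked up in the embedding table.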
The Forward Pass
# Standard: E_t = [e(x₁), ..., e(x_t)]                              ← token embeddings
# Coconut:  E_t = [e(x₁), ..., e(x_i), h_i, h_{i+1}, ..., h_{t-1}]
#                                      ↑ hidden states replace token embeddings
For n latent thoughts, there are n+1 sequential forward passes. Each of the first n passes produces the hidden state that fills the next latent position; the final pass runs over the complete sequence, including the remaining text tokens.
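To make the pass count concrete, here is a toy, dependency-free sketch of the feedback loop. `toy_model` is a hypothetical stand-in for a causal transformer (a running mean, chosen only because position t depends on positions up to t), and `coconut_forward` is an illustrative name, not a function from this repository.

```python
from typing import Callable, List, Tuple

Vector = List[float]


def toy_model(embeds: List[Vector]) -> List[Vector]:
    """Stand-in for a causal model: hidden[t] is the mean of embeds[0..t]."""
    hidden: List[Vector] = []
    running = [0.0] * len(embeds[0])
    for t, e in enumerate(embeds, start=1):
        running = [r + x for r, x in zip(running, e)]
        hidden.append([r / t for r in running])
    return hidden


def coconut_forward(prefix: List[Vector], n_latent: int, suffix: List[Vector],
                    model: Callable[[List[Vector]], List[Vector]]
                    ) -> Tuple[List[Vector], int]:
    """Fill n_latent positions one pass at a time, then run one final pass."""
    seq = list(prefix)
    passes = 0
    for _ in range(n_latent):
        hidden = model(seq)       # one forward pass over everything so far
        passes += 1
        seq.append(hidden[-1])    # last hidden state becomes the next input embedding
    hidden = model(seq + list(suffix))  # final pass covers the remaining text positions
    passes += 1
    return hidden, passes
```

Calling `coconut_forward(prefix, 3, suffix, toy_model)` performs four passes. The passes cannot be batched across latent positions because each one consumes the output of the previous one.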
Training Recipe (from paper)
| Parameter | Value |
|---|---|
| Base model | GPT-2 (124M) |
| Dataset | GSM8k (7.5K examples) |
| c (thoughts per step) | 2 |
| Curriculum stages | 4 + stage 0 |
| Stage 0 epochs | 6 |
| Per-stage epochs | 3 |
| Total epochs | 50 |
| Learning rate | 1e-4 |
| Effective batch size | 128 |
| Optimizer | AdamW (reset each stage!) |
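The epoch settings in the table can be read as a schedule that maps each epoch to a curriculum stage. The sketch below is one possible reading, not the training script's logic. The scheduled stages cover 6 + 4 × 3 = 18 epochs while the table lists 50 in total, so the code assumes that any later epochs stay in the final stage. The names `CurriculumConfig`, `stage_for_epoch`, and `starts_new_stage` are hypothetical.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class CurriculumConfig:
    stage0_epochs: int = 6      # full-CoT epochs (from the table)
    epochs_per_stage: int = 3   # epochs per latent stage (from the table)
    num_latent_stages: int = 4  # stages 1..4 (from the table)
    thoughts_per_step: int = 2  # c (from the table)


def stage_for_epoch(epoch: int, cfg: CurriculumConfig = CurriculumConfig()) -> int:
    """Map a 0-indexed epoch to its curriculum stage."""
    if epoch < cfg.stage0_epochs:
        return 0
    stage = 1 + (epoch - cfg.stage0_epochs) // cfg.epochs_per_stage
    # Assumption: epochs past the scheduled stages remain in the final stage.
    return min(stage, cfg.num_latent_stages)


def starts_new_stage(epoch: int, cfg: CurriculumConfig = CurriculumConfig()) -> bool:
    """True on the first epoch of a new stage, when the optimizer would be re-created."""
    return epoch > 0 and stage_for_epoch(epoch, cfg) != stage_for_epoch(epoch - 1, cfg)
```

In a training loop, `starts_new_stage(epoch)` would be the point to construct a fresh AdamW instance so that no moment estimates carry over between stages.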
Paper Results
| Method | GSM8k | ProntoQA | ProsQA |
|---|---|---|---|
| No CoT | 16.5% | 93.8% | 76.7% |
| Standard CoT | 42.9% | 98.8% | 77.5% |
| Coconut | 34.1% | 99.8% | 97.0% |
Key insight: Coconut excels on planning/search tasks (ProsQA: +19.5 percentage points over CoT), where BFS-like breadth in latent space is advantageous. On GSM8k it still trails standard CoT (34.1% vs. 42.9%).
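The gaps between methods are differences in percentage points and can be recomputed directly from the table. The snippet below only restates the table's numbers; the `gap` helper is an illustrative name.

```python
# Accuracy (%) copied from the results table above.
results = {
    "No CoT":       {"GSM8k": 16.5, "ProntoQA": 93.8, "ProsQA": 76.7},
    "Standard CoT": {"GSM8k": 42.9, "ProntoQA": 98.8, "ProsQA": 77.5},
    "Coconut":      {"GSM8k": 34.1, "ProntoQA": 99.8, "ProsQA": 97.0},
}


def gap(method: str, baseline: str, task: str) -> float:
    """Difference in percentage points, rounded to one decimal."""
    return round(results[method][task] - results[baseline][task], 1)


for task in ("GSM8k", "ProntoQA", "ProsQA"):
    print(f"{task}: {gap('Coconut', 'Standard CoT', task):+.1f} points vs. Standard CoT")
```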
Usage
Inference with Latent Thoughts
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

model = GPT2LMHeadModel.from_pretrained("blanar/coconut-gsm8k-gpt2")
tokenizer = GPT2Tokenizer.from_pretrained("blanar/coconut-gsm8k-gpt2")
model.eval()  # disable dropout for inference

bot_id = tokenizer.convert_tokens_to_ids("<|start-latent|>")
eot_id = tokenizer.convert_tokens_to_ids("<|end-latent|>")
wte = model.transformer.wte  # token embedding table

question = "If a store sells 3 apples for $2, how much do 12 apples cost?"
q_tokens = tokenizer.encode(f"Question: {question}\nAnswer: ")

generated = []  # token IDs of the decoded answer
with torch.no_grad():
    # Build input with <bot> token
    input_embeds = wte(torch.tensor([q_tokens]))
    bot_embed = wte(torch.tensor([[bot_id]]))
    current = torch.cat([input_embeds, bot_embed], dim=1)

    # Generate 8 latent thoughts (continuous hidden states, no text)
    n_latent = 8
    for _ in range(n_latent):
        out = model(inputs_embeds=current, output_hidden_states=True)
        h = out.hidden_states[-1][:, -1:, :]      # final-layer state at the last position
        current = torch.cat([current, h], dim=1)  # feed back as input

    # Switch back to text mode with <eot>
    eot_embed = wte(torch.tensor([[eot_id]]))
    current = torch.cat([current, eot_embed], dim=1)

    # Greedy decode the answer (now in normal text mode).
    # No KV cache is used, so every step re-runs the full sequence.
    for _ in range(100):
        out = model(inputs_embeds=current)
        next_token = out.logits[:, -1, :].argmax(-1)
        if next_token.item() == tokenizer.eos_token_id:
            break
        generated.append(next_token.item())
        next_embed = wte(next_token.unsqueeze(0))
        current = torch.cat([current, next_embed], dim=1)

print(tokenizer.decode(generated))
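The curriculum layouts end each sequence with `#### answer`, so the decoded text presumably carries the final number after that marker. A small parser for it might look like the sketch below. Both `extract_answer` and the exact output format are assumptions, since this README does not show a sample generation.

```python
import re
from typing import Optional


def extract_answer(text: str) -> Optional[str]:
    """Return the number after the last '####' marker, or None if there is none."""
    # Optional '$', optional sign, digits with optional thousands separators and decimals.
    matches = re.findall(r"####\s*\$?(-?[\d,]*\.?\d+)", text)
    if not matches:
        return None
    return matches[-1].replace(",", "")
```

For example, `extract_answer("#### 8")` returns `"8"`, and text without a marker returns `None`.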
Training
pip install transformers datasets torch trackio huggingface_hub
python coconut_train.py
Special Tokens
| Token | ID | Purpose |
|---|---|---|
| `<\|start-latent\|>` | 50257 | Begin latent thought block |
| `<\|end-latent\|>` | 50258 | End latent thought block, resume text |
| `<\|latent\|>` | 50259 | Placeholder for each latent position |
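These IDs follow from appending three tokens to GPT-2's base vocabulary, which has 50,257 entries (IDs 0 to 50256). The sketch below reproduces that numbering. The helper `assign_new_token_ids` is illustrative. With Hugging Face `transformers`, the equivalent step would typically be `tokenizer.add_tokens(...)` followed by `model.resize_token_embeddings(len(tokenizer))`, but whether `coconut_train.py` does exactly that is an assumption.

```python
from typing import Dict, List

GPT2_VOCAB_SIZE = 50257  # GPT-2's original vocabulary: IDs 0..50256


def assign_new_token_ids(new_tokens: List[str],
                         base_vocab_size: int = GPT2_VOCAB_SIZE) -> Dict[str, int]:
    """Appended tokens receive consecutive IDs starting at the old vocabulary size."""
    return {tok: base_vocab_size + i for i, tok in enumerate(new_tokens)}


special = assign_new_token_ids(["<|start-latent|>", "<|end-latent|>", "<|latent|>"])
print(special)
```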
References
- Coconut Paper — Meta FAIR, Dec 2024
- Official Code — Facebook Research
- Pause Tokens — Google DeepMind, Oct 2023
- iCoT — Implicit Chain-of-Thought, 2024
- LT-Tuning — Latent Thoughts Tuning, Feb 2025
- HRPO — Hybrid Reasoning Policy Optimization, May 2025
Advanced: Scaling Beyond GPT-2
For models >3B params, Coconut's hidden states can diverge from the input embedding space. Consider:
- LT-Tuning (arxiv 2602.10229): Fuses hidden states with soft-vocabulary embeddings (α=0.6)
- HRPO (arxiv 2505.18454): Learnable gating between token embeddings and hidden states
- Adapter layer for models with untied embeddings (Llama 8B+)
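As a rough picture of what fusing a hidden state with a soft-vocabulary embedding could mean, the sketch below blends the hidden state with the probability-weighted average of token embeddings. It is a hedged illustration in plain Python, not the LT-Tuning or HRPO method: the function names are hypothetical, and which term α = 0.6 weights is an assumption, since this README only gives the value.

```python
import math
from typing import List

Vector = List[float]


def softmax(logits: Vector) -> Vector:
    """Numerically stable softmax."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def soft_vocab_embedding(logits: Vector, embedding_table: List[Vector]) -> Vector:
    """Probability-weighted average of the token embeddings."""
    probs = softmax(logits)
    dim = len(embedding_table[0])
    return [sum(p * row[d] for p, row in zip(probs, embedding_table)) for d in range(dim)]


def fuse(hidden: Vector, logits: Vector, embedding_table: List[Vector],
         alpha: float = 0.6) -> Vector:
    """Convex blend of the hidden state and the soft-vocabulary embedding."""
    soft = soft_vocab_embedding(logits, embedding_table)
    # Assumption: alpha weights the hidden state and (1 - alpha) the soft embedding.
    return [alpha * h + (1 - alpha) * s for h, s in zip(hidden, soft)]
```

The soft embedding lies inside the span of the token embeddings, so blending it in pulls the fed-back vector toward the space the model's input layer was trained on. A learnable gate, as described for HRPO, would replace the fixed `alpha` with a value computed by the model.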