🥥 Coconut: Chain of Continuous Thought

LLMs reasoning in latent space instead of generating text tokens.

Implementation of "Training Large Language Models to Reason in a Continuous Latent Space" (Meta FAIR, Dec 2024).

The Idea

Standard chain-of-thought (CoT) forces a model to serialize its reasoning into discrete tokens, which the paper argues is a bottleneck. Coconut replaces explicit reasoning steps with continuous hidden-state vectors that are fed directly back as input embeddings:

Standard CoT:  [Question] → "Step 1..." → "Step 2..." → "Step 3..." → [Answer]
                              ↑ tokens ↑    ↑ tokens ↑    ↑ tokens ↑

Coconut:       [Question] → [h₁] → [h₂] → [h₃] → [Answer]
                             ↑      ↑      ↑
                  hidden states (768-dim vectors, NOT text)

The model "thinks" in a continuous 768-dimensional space where it can:

  • Represent uncertainty (continuous values vs discrete tokens)
  • Maintain multiple hypotheses simultaneously (BFS-like breadth)
  • Skip the tokenization bottleneck entirely

How It Works

Multi-Stage Curriculum Training

Stage 0: [Q] step₁ step₂ step₃ step₄ #### answer              ← Full CoT (standard SFT)
Stage 1: [Q] <bot> h h <eot> step₂ step₃ step₄ #### answer    ← 1 step → latent
Stage 2: [Q] <bot> h h h h <eot> step₃ step₄ #### answer      ← 2 steps → latent
Stage 3: [Q] <bot> h h h h h h <eot> step₄ #### answer        ← 3 steps → latent
Stage 4: [Q] <bot> h h h h h h h h <eot> #### answer          ← ALL steps → latent

Each h is a continuous thought — the model's last-layer hidden state fed directly as the next input embedding. The optimizer is completely reset at each stage transition.
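
The stage layouts above can be sketched as a small sequence builder. This is a minimal illustration, not the code in coconut_train.py: the function name `build_stage_sequence` and the use of string pieces instead of token ids are assumptions made for readability.

```python
from typing import List

BOT, EOT, LATENT = "<|start-latent|>", "<|end-latent|>", "<|latent|>"


def build_stage_sequence(question: str, steps: List[str], answer: str,
                         stage: int, c: int = 2) -> List[str]:
    """Return the training sequence (as string pieces) for one curriculum stage.

    At stage k, the first k reasoning steps are replaced by c * k latent
    placeholders wrapped in <bot>/<eot>; the remaining steps stay as text.
    """
    if stage < 0:
        raise ValueError("stage must be non-negative")
    k = min(stage, len(steps))      # number of steps moved into latent space
    seq = [question]
    if k > 0:
        seq += [BOT] + [LATENT] * (c * k) + [EOT]
    seq += steps[k:]                # steps that are still written out as text
    seq += ["####", answer]
    return seq


if __name__ == "__main__":
    steps = ["step1", "step2", "step3", "step4"]
    for stage in range(5):
        print(stage, " ".join(build_stage_sequence("[Q]", steps, "answer", stage)))
```

With c = 2 and four steps, stage 4 yields eight latent placeholders and no text steps, matching the last row of the diagram.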

The Forward Pass

# Standard:  E_t = [e(x₁), ..., e(x_t)]           ← token embeddings
# Coconut:   E_t = [e(x₁), ..., e(x_i), h_i, h_{i+1}, ..., h_{t-1}]
#                                         ↑ hidden states replace token embeddings

For n latent thoughts, there are n+1 sequential forward passes. Each of the first n passes produces the hidden state that fills the next latent position; the final pass processes the remaining text tokens.
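
To make the n+1 count concrete, the sketch below splits a sequence into the spans each pass would compute. It assumes a KV cache, so each pass only processes new positions; the helper name `plan_forward_passes` is hypothetical and not part of this repository.

```python
from typing import List, Tuple


def plan_forward_passes(seq_len: int,
                        latent_positions: List[int]) -> List[Tuple[int, int]]:
    """Return the half-open [start, end) span of positions each pass computes.

    A pass stops right before latent position p, because the input embedding
    at p is the hidden state that this same pass produces at position p - 1.
    """
    positions = sorted(latent_positions)
    if positions and (positions[0] < 1 or positions[-1] >= seq_len):
        raise ValueError("latent positions must lie in [1, seq_len)")
    spans = []
    start = 0
    for p in positions:
        spans.append((start, p))    # compute up to p - 1, then fill slot p
        start = p
    spans.append((start, seq_len))  # final pass covers the rest of the sequence
    return spans


if __name__ == "__main__":
    # 5 question tokens, <bot> at index 5, latents at 6..8, <eot> at 9, 2 answer tokens
    print(plan_forward_passes(12, [6, 7, 8]))
```

For three latent positions this produces four spans. Without a KV cache every pass would recompute from position 0, but the number of passes stays the same.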

Training Recipe (from paper)

Parameter                Value
Base model               GPT-2 (124M)
Dataset                  GSM8k (7.5K examples)
c (thoughts per step)    2
Curriculum stages        4 + stage 0
Stage 0 epochs           6
Per-stage epochs         3
Total epochs             50
Learning rate            1e-4
Effective batch size     128
Optimizer                AdamW (reset each stage!)
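
The epoch settings in the table imply a simple epoch-to-stage mapping, sketched below. The function names are hypothetical. The table does not say how epochs beyond the scheduled 6 + 4 × 3 = 18 are spent, so keeping them in the final stage is an assumption.

```python
def stage_for_epoch(epoch: int, stage0_epochs: int = 6,
                    epochs_per_stage: int = 3, max_stage: int = 4) -> int:
    """Map a zero-based epoch index to its curriculum stage."""
    if epoch < 0:
        raise ValueError("epoch must be non-negative")
    if epoch < stage0_epochs:
        return 0
    stage = 1 + (epoch - stage0_epochs) // epochs_per_stage
    # Assumption: epochs past the last scheduled stage stay in the final stage.
    return min(stage, max_stage)


def needs_optimizer_reset(epoch: int) -> bool:
    """True when `epoch` is the first epoch of a new stage."""
    return epoch > 0 and stage_for_epoch(epoch) != stage_for_epoch(epoch - 1)


def latents_for_stage(stage: int, c: int = 2) -> int:
    """Number of latent thoughts at a stage: c thoughts per replaced step."""
    return c * stage


if __name__ == "__main__":
    for epoch in range(20):
        stage = stage_for_epoch(epoch)
        print(epoch, stage, latents_for_stage(stage), needs_optimizer_reset(epoch))
```

Under these assumptions the optimizer would be re-created at the start of epochs 6, 9, 12, and 15.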

Paper Results

Method          GSM8k    ProntoQA    ProsQA
No CoT          16.5%    93.8%       76.7%
Standard CoT    42.9%    98.8%       77.5%
Coconut         34.1%    99.8%       97.0%

Key insight: Coconut excels on planning/search tasks (ProsQA: +19.5 percentage points over CoT), where BFS-like breadth in latent space is advantageous. On GSM8k it still trails standard CoT (34.1% vs 42.9%).
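
The gaps between Coconut and standard CoT can be recomputed directly from the table. The snippet below is only arithmetic on the reported numbers; the dictionary layout and function name are illustrative.

```python
# Accuracy (%) as reported in the results table above.
RESULTS = {
    "GSM8k":    {"No CoT": 16.5, "Standard CoT": 42.9, "Coconut": 34.1},
    "ProntoQA": {"No CoT": 93.8, "Standard CoT": 98.8, "Coconut": 99.8},
    "ProsQA":   {"No CoT": 76.7, "Standard CoT": 77.5, "Coconut": 97.0},
}


def delta_vs_cot(benchmark: str) -> float:
    """Coconut minus standard CoT, in percentage points (rounded to 0.1)."""
    row = RESULTS[benchmark]
    return round(row["Coconut"] - row["Standard CoT"], 1)


if __name__ == "__main__":
    for name in RESULTS:
        print(f"{name}: {delta_vs_cot(name):+.1f} points")
```

The differences are absolute percentage points, not relative improvements.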

Usage

Inference with Latent Thoughts

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

model = GPT2LMHeadModel.from_pretrained("blanar/coconut-gsm8k-gpt2")
tokenizer = GPT2Tokenizer.from_pretrained("blanar/coconut-gsm8k-gpt2")
model.eval()  # inference only: disable dropout

bot_id = tokenizer.convert_tokens_to_ids("<|start-latent|>")
eot_id = tokenizer.convert_tokens_to_ids("<|end-latent|>")

question = "If a store sells 3 apples for $2, how much do 12 apples cost?"
q_tokens = tokenizer.encode(f"Question: {question}\nAnswer: ")

embed = model.transformer.wte  # token embedding table

with torch.no_grad():  # no gradients needed at inference time
    # Build input with <bot> token
    input_embeds = embed(torch.tensor([q_tokens]))
    bot_embed = embed(torch.tensor([[bot_id]]))
    current = torch.cat([input_embeds, bot_embed], dim=1)

    # Generate 8 latent thoughts (continuous hidden states, no text)
    n_latent = 8
    for _ in range(n_latent):
        out = model(inputs_embeds=current, output_hidden_states=True)
        h = out.hidden_states[-1][:, -1:, :]  # final-layer hidden state at the last position
        current = torch.cat([current, h], dim=1)  # feed back as the next input embedding

    # Switch back to text mode with <eot>
    eot_embed = embed(torch.tensor([[eot_id]]))
    current = torch.cat([current, eot_embed], dim=1)

    # Greedy decode the answer (now in normal text mode)
    generated = []  # token ids of the decoded answer
    for _ in range(100):
        out = model(inputs_embeds=current)
        next_token = out.logits[:, -1, :].argmax(-1)
        if next_token.item() == tokenizer.eos_token_id:
            break
        generated.append(next_token.item())
        next_embed = embed(next_token.unsqueeze(0))
        current = torch.cat([current, next_embed], dim=1)

print(tokenizer.decode(generated))
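
Once the generated token ids are decoded to a string, the final answer still has to be pulled out of it. The sketch below assumes the model keeps the `#### answer` convention from the training format shown earlier; the helper name `extract_answer` is hypothetical.

```python
import re
from typing import Optional


def extract_answer(decoded: str) -> Optional[str]:
    """Return the number that follows '####', or None if the marker is absent.

    Thousands separators and a leading dollar sign are stripped so that the
    result can be compared against a reference answer as a plain string.
    """
    match = re.search(r"####\s*\$?\s*(-?[\d,]*\.?\d+)", decoded)
    if match is None:
        return None
    return match.group(1).replace(",", "")


if __name__ == "__main__":
    print(extract_answer("#### 8"))
    print(extract_answer("so the total is #### $1,250.50"))
    print(extract_answer("no marker here"))
```

If the decoded text has no `####` marker, the function returns None rather than guessing.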

Training

pip install transformers datasets torch trackio huggingface_hub
python coconut_train.py

Special Tokens

Token               ID       Purpose
<|start-latent|>    50257    Begin latent thought block
<|end-latent|>      50258    End latent thought block, resume text
<|latent|>          50259    Placeholder for each latent position
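
During training, the `<|latent|>` placeholders mark where hidden states are substituted for token embeddings. The sketch below locates those positions in a list of token ids and checks that they form one block framed by the start and end tokens. The ids come from the table above; the function name and the single-block check are assumptions.

```python
from typing import List

START_LATENT_ID, END_LATENT_ID, LATENT_ID = 50257, 50258, 50259


def latent_positions(input_ids: List[int]) -> List[int]:
    """Return indices of <|latent|> placeholders, validating the block layout."""
    positions = [i for i, tok in enumerate(input_ids) if tok == LATENT_ID]
    if not positions:
        return positions
    first, last = positions[0], positions[-1]
    contiguous = positions == list(range(first, last + 1))
    framed = (
        first > 0
        and input_ids[first - 1] == START_LATENT_ID
        and last + 1 < len(input_ids)
        and input_ids[last + 1] == END_LATENT_ID
    )
    if not (contiguous and framed):
        raise ValueError("latent placeholders must form one <bot> ... <eot> block")
    return positions


if __name__ == "__main__":
    ids = [10, 11, START_LATENT_ID, LATENT_ID, LATENT_ID, END_LATENT_ID, 12]
    print(latent_positions(ids))
```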

Advanced: Scaling Beyond GPT-2

For models with more than 3B parameters, Coconut's hidden states can diverge from the input embedding space. Options to consider:

  • LT-Tuning (arXiv 2602.10229): Fuses hidden states with soft-vocabulary embeddings (α=0.6)
  • HRPO (arXiv 2505.18454): Learnable gating between token embeddings and hidden states
  • An adapter layer for models with untied input and output embeddings (Llama 8B+)
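
As a rough illustration of the fusion idea in the first bullet, the sketch below blends a hidden state with a probability-weighted average of token embeddings. It is a generic convex interpolation, not the formula from either cited paper. Both the definition of the soft-vocabulary embedding (softmax over logits times the embedding table) and the choice to apply α to the hidden state are assumptions.

```python
import math
from typing import List


def soft_vocab_embedding(logits: List[float],
                         embeddings: List[List[float]]) -> List[float]:
    """Expected token embedding under softmax(logits) (assumed definition)."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]   # numerically stable softmax
    total = sum(exps)
    probs = [e / total for e in exps]
    dim = len(embeddings[0])
    return [sum(p * row[d] for p, row in zip(probs, embeddings))
            for d in range(dim)]


def fuse(hidden: List[float], soft_embed: List[float],
         alpha: float = 0.6) -> List[float]:
    """Blend alpha * hidden + (1 - alpha) * soft_embed, element by element."""
    if len(hidden) != len(soft_embed):
        raise ValueError("vectors must have the same dimension")
    return [alpha * h + (1.0 - alpha) * e for h, e in zip(hidden, soft_embed)]


if __name__ == "__main__":
    table = [[1.0, 0.0], [0.0, 1.0]]              # toy 2-token, 2-dim embedding table
    soft = soft_vocab_embedding([0.0, 0.0], table)
    print(fuse([1.0, 0.0], soft))
```

The blend pulls the fed-back vector toward the span of the embedding table, which is the stated motivation for fusing in the first place.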