Everyone is scaling parameters and GPU farms, but what if there's another way? 🤔 Say, like the brain: learn only from mistakes, from whatever is surprising, and leave the boring stuff alone. On top of that I've added bandits, FlashAttention and MoE-Pruner, so all together it's a kind of AI SlimFast.
Probably nonsense… but maybe also a path to AGI on a single graphics card from Alza. Here's the skeleton, judge for yourselves.
Surprise-Driven Training (SDT) – Functional Skeleton
This document contains a complete functional skeleton for LLM training built on the principle of predictive coding: we learn primarily from surprise (tokens with high per-token loss). It includes the HF stack, optional PEFT (LoRA/QLoRA), a best-effort FlashAttention path, and a simple UCB bandit over sharded data.
README_SDT.md
# Surprise-Driven Training (SDT) – functional skeleton
**Concept:** Train on *surprise* only – compute per-token loss (NLL) and backprop only through the most surprising tokens. This approximates predictive coding and can reduce wasted compute on trivial/predictable text.
## Highlights
- HuggingFace Transformers + Datasets
- Optional PEFT (LoRA/QLoRA)
- FlashAttention v2 hint (`--flash_attn`) when supported
- Selective loss: `--surprise_top_p` (keep top-p tokens by loss) or EMA thresholding; see the sketch after this list
- Optional UCB bandit over data shards (`--use_ucb`)
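
A minimal sketch of the selective loss (standalone, with toy shapes; not the training script itself):

```python
import torch
import torch.nn.functional as F

# Toy shapes: batch of 2 sequences, 8 positions, vocab of 100.
logits = torch.randn(2, 8, 100)
labels = torch.randint(0, 100, (2, 8))

# Per-token NLL with the usual next-token shift: position t scores token t+1.
nll = F.cross_entropy(
    logits[:, :-1].reshape(-1, 100),
    labels[:, 1:].reshape(-1),
    reduction="none",
)

# Keep only the top 20% most surprising tokens in the loss.
k = max(1, int(0.2 * nll.numel()))
keep = torch.zeros_like(nll, dtype=torch.bool)
keep[torch.topk(nll, k).indices] = True
loss = nll[keep].mean()  # gradients flow only through the kept tokens
```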
## Install
```bash
pip install torch transformers datasets accelerate peft trl evaluate bitsandbytes tqdm numpy pyyaml
```
## Run (toy, 1×GPU)
```bash
python train_sdt.py \
--model_name TinyLlama/TinyLlama-1.1B-Chat-v1.0 \
--peft lora \
--surprise_top_p 0.2 \
--flash_attn \
--steps 1000 --val_every 200
```
## Notes
- If your model/config supports FlashAttention v2 in HF, `--flash_attn` will try to enable it (best on recent GPUs).
- For QLoRA use `--peft qlora --use_4bit`.
- With `--ema_threshold` the selection adapts to dataset difficulty (target ~70th percentile NLL); the update rule is sketched below.
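
A minimal sketch of the EMA threshold update (the same rule `ema_threshold_mask` in `train_sdt.py` uses; the numbers are invented for illustration):

```python
tau = 0.99
ema_thresh = 2.5   # current NLL threshold
batch_p70 = 3.1    # 70th-percentile NLL of the current batch

# Tokens with NLL above ema_thresh enter the loss; the threshold then
# drifts slowly toward each batch's 70th percentile.
ema_thresh = tau * ema_thresh + (1 - tau) * batch_p70   # 2.475 + 0.031 = 2.506
```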
## Why this is different
Instead of averaging the loss over **all** tokens, we optimize only the informative ones. This can:
- reduce gradient noise,
- focus the learning signal where it matters (the forward/backward pass still touches every token; only the loss is masked),
- reach the same validation loss in fewer GPU-hours (to be validated per task).
## Roadmap
- Add multi-scale (hierarchical) units (sentence/paragraph tokens) for renormalization-like training.
- Add proper contextual bandits (LinUCB/Thompson) with feature vectors per sample.
- Combine with MoE expert pruning + knowledge distillation.
train_sdt.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Surprise-Driven Training (SDT) – functional skeleton (HF, 2025)
Idea: learn primarily from "surprises" (high per-token loss) instead of all tokens.
This approximates predictive coding: only prediction errors drive updates.
Features:
- HF Transformers + Datasets
- Optional PEFT (LoRA/QLoRA) via PEFT
- FlashAttention v2 hint (attn_implementation="flash_attention_2") when supported
- Selective loss: top-p percentile of per-token loss per batch (or EMA threshold)
- Simple UCB bandit over shards (optional)
Usage (toy, single GPU):
pip install -r requirements.txt
python train_sdt.py --model_name TinyLlama/TinyLlama-1.1B-Chat-v1.0 --peft lora --steps 1000
"""
import argparse, math, os, random
from dataclasses import dataclass
from typing import Dict, Any, List, Tuple
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import (
AutoTokenizer,
AutoConfig,
AutoModelForCausalLM,
DataCollatorForLanguageModeling,
get_cosine_schedule_with_warmup,
)
from accelerate import Accelerator
from tqdm import tqdm
try:
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
PEFT_AVAILABLE = True
except Exception:
PEFT_AVAILABLE = False
# --------------------- Config ---------------------
def parse_args():
ap = argparse.ArgumentParser()
ap.add_argument("--model_name", default="TinyLlama/TinyLlama-1.1B-Chat-v1.0")
ap.add_argument("--dataset", default="wikitext")
ap.add_argument("--subset", default="wikitext-2-raw-v1")
ap.add_argument("--seq_len", type=int, default=1024)
ap.add_argument("--bsz", type=int, default=8)
ap.add_argument("--steps", type=int, default=2000)
ap.add_argument("--warmup", type=int, default=100)
ap.add_argument("--lr", type=float, default=2e-4)
ap.add_argument("--weight_decay", type=float, default=0.02)
# Surprise-Driven knobs
ap.add_argument("--surprise_top_p", type=float, default=0.2, help="Keep top-p most surprising tokens in loss")
ap.add_argument("--ema_threshold", action="store_true", help="Use EMA threshold instead of top-p")
ap.add_argument("--ema_tau", type=float, default=0.99)
ap.add_argument("--min_tokens", type=int, default=64, help="Minimum tokens to keep per batch in loss")
# Bandit sharding (optional lightweight UCB)
ap.add_argument("--shards", type=int, default=4)
ap.add_argument("--use_ucb", action="store_true")
# PEFT
ap.add_argument("--peft", choices=["none","lora","qlora"], default="none")
ap.add_argument("--lora_r", type=int, default=16)
ap.add_argument("--use_4bit", action="store_true")
# FlashAttention hint
ap.add_argument("--flash_attn", action="store_true")
ap.add_argument("--val_every", type=int, default=200)
ap.add_argument("--save_dir", default="checkpoints_sdt")
return ap.parse_args()
# --------------------- Utils ---------------------
def set_seed(seed: int = 1337):
import numpy as np
random.seed(seed); torch.manual_seed(seed); torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
def maybe_enable_flash_attention(cfg: AutoConfig, enable: bool):
    # Best-effort hint only: HF actually selects the attention backend via the
    # attn_implementation kwarg to from_pretrained() (see main()).
    if not enable:
        return cfg
    try:
        cfg.attn_implementation = "flash_attention_2"
    except Exception:
        pass
    return cfg
def attach_peft_if_needed(model, method="none", rank=16, use_4bit=False):
if method == "none":
return model
if not PEFT_AVAILABLE:
print("[WARN] PEFT not available; ignoring --peft")
return model
if method == "qlora" or use_4bit:
model = prepare_model_for_kbit_training(model)
if method in ("lora","qlora"):
peft_cfg = LoraConfig(
r=rank, lora_alpha=32, lora_dropout=0.05,
target_modules=["q_proj","k_proj","v_proj","o_proj"],
bias="none", task_type="CAUSAL_LM"
)
model = get_peft_model(model, peft_cfg)
try:
model.print_trainable_parameters()
except Exception:
pass
return model
# --------------------- Data ---------------------
def make_loaders(tokenizer, dataset_id, subset, seq_len, bsz):
    ds = load_dataset(dataset_id, subset)
    # wikitext is full of empty lines; drop them so the collator never sees
    # zero-length examples.
    ds = ds.filter(lambda ex: len(ex["text"].strip()) > 0)
    def tok(batch):
        return tokenizer(batch["text"], truncation=True, max_length=seq_len)
    ds = ds.map(tok, batched=True, remove_columns=ds["train"].column_names)
    collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
    train_loader = DataLoader(ds["train"], batch_size=bsz, shuffle=True, collate_fn=collator)
    val_loader = DataLoader(ds["validation"], batch_size=bsz, shuffle=False, collate_fn=collator)
    return train_loader, val_loader
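# Note: DataCollatorForLanguageModeling(mlm=False) emits labels equal to
# input_ids with -100 at pad positions. Because pad_token == eos_token here,
# genuine EOS tokens are masked out of the loss as well -- a known trade-off.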
# --------------------- Surprise mask ---------------------
@dataclass
class SurpriseState:
ema_thresh: float = 2.5 # initial nll threshold
def per_token_nll(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # logits: [B, T, V], labels: [B, T]; returns per-token NLL of shape [B, T-1].
    # Padding positions (label -100, as emitted by the collator) get NLL 0.
    B, T, V = logits.shape
    logits = logits[:, :-1, :].contiguous()
    labels = labels[:, 1:].contiguous()
    nll = F.cross_entropy(
        logits.view(-1, V),
        labels.view(-1),
        reduction="none",
        ignore_index=-100,
    ).view(B, T - 1)
    return nll
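# Example (illustrative shapes): logits [2, 8, V] with labels [2, 8] yields
# nll of shape [2, 7]; entry (b, t) is the NLL of predicting token t+1.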
def top_p_mask(nll: torch.Tensor, top_p: float, min_tokens: int) -> torch.Tensor:
    # Select the top_p fraction of highest-NLL tokens across the whole batch.
    B, Tm1 = nll.shape
    flat = nll.view(-1)
    k = max(min_tokens, int(flat.numel() * max(1e-4, min(top_p, 1.0))))
    k = min(k, flat.numel())  # never request more tokens than exist
    topk = torch.topk(flat, k=k, sorted=False).indices
    mask = torch.zeros_like(flat, dtype=torch.bool)
    mask[topk] = True
    return mask.view(B, Tm1)
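# Example: with bsz=8 and seq_len=1024 there are 8 * 1023 = 8184 scored tokens,
# so top_p=0.2 keeps int(0.2 * 8184) = 1636 of them (never fewer than min_tokens).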
def ema_threshold_mask(nll: torch.Tensor, state: SurpriseState, tau: float, min_tokens: int) -> torch.Tensor:
# Keep tokens with NLL > ema threshold; update EMA to track running difficulty
with torch.no_grad():
th = state.ema_thresh
sel = nll > th
# Ensure at least min_tokens
if sel.sum().item() < min_tokens:
# fallback: take highest NLL tokens to reach min_tokens
add_mask = top_p_mask(nll, top_p=min_tokens / nll.numel(), min_tokens=min_tokens)
sel = sel | add_mask
# Update EMA threshold toward e.g. 70th percentile
target = torch.quantile(nll.detach().float().view(-1), 0.7).item()
state.ema_thresh = tau * state.ema_thresh + (1 - tau) * target
return sel
def masked_mean_loss(nll: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
# Average only over selected tokens
sel = mask
denom = sel.sum()
if denom == 0:
return nll.mean()
return (nll * sel).sum() / denom
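# NOTE: the mask filters the gradient *signal*, not the compute -- backward
# still runs over the full graph, so any GPU-hour savings must come from
# faster convergence rather than cheaper steps.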
# --------------------- Bandit (UCB over shards) ---------------------
class UCB:
    def __init__(self, n_arms: int, c: float = 1.2):
        import numpy as np
        self.np = np
        self.n = n_arms
        self.c = c
        self.counts = np.zeros(n_arms, dtype=np.int64)
        self.values = np.zeros(n_arms, dtype=np.float32)
    def select(self, step: int) -> int:
        total = max(1, self.counts.sum())
        ucb = self.values + self.c * self.np.sqrt(self.np.log(total) / (self.counts + 1e-9))
        return int(ucb.argmax())
    def update(self, arm: int, reward: float):
        self.counts[arm] += 1
        # EMA of rewards (rather than a running mean) keeps the bandit
        # responsive when shard difficulty drifts during training.
        beta = 0.1
        self.values[arm] = (1 - beta) * self.values[arm] + beta * reward
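# Usage sketch: arm = ucb.select(step); <train on shard `arm`>; ucb.update(arm, r)
# where r is any higher-is-better scalar (here: the negative masked loss).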
# --------------------- Validation ---------------------
def eval_ppl(model, dataloader, accelerator) -> float:
    model.eval()
    device = accelerator.device
    nll_total, tok_total = 0.0, 0
    with torch.no_grad():
        for batch in dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}
            out = model(**batch)  # collator already supplies labels (-100 on padding)
            n_tok = int((batch["labels"] != -100).sum().item())
            nll_total += out.loss.item() * n_tok
            tok_total += n_tok
    model.train()
    # Sum across processes before computing perplexity.
    nll_total = accelerator.gather_for_metrics(torch.tensor([nll_total], device=device)).sum().item()
    tok_total = accelerator.gather_for_metrics(torch.tensor([tok_total], device=device)).sum().item()
    return math.exp(nll_total / max(1, tok_total))
# --------------------- Main ---------------------
def main():
args = parse_args()
set_seed(1337)
accelerator = Accelerator(mixed_precision="bf16")
device = accelerator.device
# Tokenizer
tokenizer = AutoTokenizer.from_pretrained(args.model_name, use_fast=True)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
# Config + FlashAttention hint
cfg = AutoConfig.from_pretrained(args.model_name)
cfg = maybe_enable_flash_attention(cfg, args.flash_attn)
# Model
    model_kwargs = {"torch_dtype": torch.bfloat16}
    if args.use_4bit:
        from transformers import BitsAndBytesConfig
        model_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16
        )
    if args.flash_attn:
        # HF selects the attention backend via from_pretrained, not via a
        # post-hoc config attribute, so pass it here and fall back on failure.
        model_kwargs["attn_implementation"] = "flash_attention_2"
    try:
        model = AutoModelForCausalLM.from_pretrained(args.model_name, config=cfg, **model_kwargs)
    except Exception as e:
        if args.flash_attn:
            print(f"[WARN] FlashAttention v2 unavailable ({e}); falling back to default attention.")
            model_kwargs.pop("attn_implementation", None)
            model = AutoModelForCausalLM.from_pretrained(args.model_name, config=cfg, **model_kwargs)
        else:
            raise
# PEFT
model = attach_peft_if_needed(model, method=args.peft, rank=args.lora_r, use_4bit=args.use_4bit)
# Data
train_loader, val_loader = make_loaders(tokenizer, args.dataset, args.subset, args.seq_len, args.bsz)
# Simple sharding for UCB
n_shards = max(1, args.shards)
shards = [list(range(i, len(train_loader), n_shards)) for i in range(n_shards)]
ucb = UCB(n_shards) if args.use_ucb else None
optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
scheduler = get_cosine_schedule_with_warmup(optimizer, args.warmup, args.steps)
model, optimizer, train_loader, val_loader, scheduler = accelerator.prepare(
model, optimizer, train_loader, val_loader, scheduler
)
surprise_state = SurpriseState()
model.train()
step = 0
pbar = tqdm(total=args.steps, disable=not accelerator.is_local_main_process)
# To index DataLoader deterministically per "step", we materialize it once.
train_batches = list(train_loader)
while step < args.steps:
# Bandit shard selection
if ucb is not None:
shard_idx = ucb.select(step)
shard_indices = shards[shard_idx]
batch = train_batches[ shard_indices[ step % len(shard_indices) ] ]
else:
batch = train_batches[ step % len(train_batches) ]
        batch = {k: v.to(device) for k, v in batch.items()}
        # Pop the collator's labels (-100 on padding) so **batch does not clash
        # with an explicit labels kwarg; we compute our own per-token loss.
        labels = batch.pop("labels")
        # Forward
        outputs = model(**batch)
        nll = per_token_nll(outputs.logits, labels)  # [B, T-1]
# Surprise selection
if args.ema_threshold:
mask = ema_threshold_mask(nll, surprise_state, tau=args.ema_tau, min_tokens=args.min_tokens)
else:
mask = top_p_mask(nll, top_p=args.surprise_top_p, min_tokens=args.min_tokens)
loss = masked_mean_loss(nll, mask)
accelerator.backward(loss)
optimizer.step(); scheduler.step(); optimizer.zero_grad()
# Bandit reward: use negative masked loss (lower is better)
if ucb is not None:
rew = float(-accelerator.gather_for_metrics(loss.detach()).mean().item())
ucb.update(shard_idx, rew)
        if step % args.val_every == 0 and step > 0:
            # eval_ppl gathers across processes, so every rank must call it.
            ppl = eval_ppl(model, val_loader, accelerator)
            if accelerator.is_local_main_process:
                print(f"[step {step}] ppl={ppl:.2f}")
                os.makedirs(args.save_dir, exist_ok=True)
                save_path = os.path.join(args.save_dir, f"ckpt_{step}.pt")
                torch.save({"model": accelerator.unwrap_model(model).state_dict(), "step": step}, save_path)
step += 1
pbar.update(1)
if accelerator.is_local_main_process:
os.makedirs(args.save_dir, exist_ok=True)
torch.save({"model": accelerator.unwrap_model(model).state_dict(), "step": step},
os.path.join(args.save_dir, f"ckpt_{step}.pt"))
print("Done.")
if __name__ == "__main__":
    main()