Comments on the article
Tomáš Mikolov's startup turns to the community for help with radically more efficient AI

  • 25. 9. 2025 17:13

    Vaclav Provod (unregistered)

    Everyone is scaling up parameters and GPU farms, but what if there is another way? 🤔 Say, like the brain: learn only from mistakes, from what is surprising, and leave the boring stuff alone. On top of that I added bandits, FlashAttention and MoE-Pruner, so together it's a kind of AI SlimFast.
    Probably nonsense… but maybe also a path to AGI on a single graphics card from Alza. Here's the skeleton, judge for yourselves

    Surprise-Driven Training (SDT) – Functional Skeleton
    This document contains a complete functional skeleton for LLM training built on the principle of predictive coding: we learn primarily from surprise (tokens with a high per-token loss). It includes the HF stack, optional PEFT (LoRA/QLoRA), an attempt at FlashAttention, and a simple UCB bandit over sharded data.

    README_SDT.md
    # Surprise-Driven Training (SDT) – functional skeleton

    **Concept:** Train on *surprise* only – compute per-token loss (NLL) and backprop only through the most surprising tokens. This approximates predictive coding and can reduce wasted compute on trivial/predictable text.
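
    In symbols (matching `per_token_nll` and `masked_mean_loss` in the script): with per-token NLL $\ell_t = -\log p_\theta(x_t \mid x_{<t})$ over a batch of $T$ shifted targets, only the selected set $S$ enters the objective:

    $$
    \mathcal{L}_{\text{SDT}} = \frac{1}{|S|} \sum_{t \in S} \ell_t,
    \qquad
    S = \operatorname{arg\,top\text{-}k}\bigl(\{\ell_1,\dots,\ell_T\}\bigr),\quad
    k = \max(\texttt{min\_tokens},\ \lceil p\,T \rceil)
    $$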

    ## Highlights
    - HuggingFace Transformers + Datasets
    - Optional PEFT (LoRA/QLoRA)
    - FlashAttention v2 hint (`--flash_attn`) when supported
    - Selective loss: `--surprise_top_p` (keep top-p tokens by loss) or EMA thresholding
    - Optional UCB bandit over data shards (`--use_ucb`)

    ## Install
    ```bash
    pip install torch transformers datasets accelerate peft trl evaluate bitsandbytes tqdm numpy pyyaml
    ```
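
    The script's docstring refers to a `requirements.txt`; a minimal, unpinned one mirroring the install line above could be:

    ```
    torch
    transformers
    datasets
    accelerate
    peft
    trl
    evaluate
    bitsandbytes
    tqdm
    numpy
    pyyaml
    ```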

    ## Run (toy, 1×GPU)
    ```bash
    python train_sdt.py \
      --model_name TinyLlama/TinyLlama-1.1B-Chat-v1.0 \
      --peft lora \
      --surprise_top_p 0.2 \
      --flash_attn \
      --steps 1000 --val_every 200
    ```
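
    The QLoRA path mentioned in the Notes below uses the same script; a toy invocation (flags as defined in `parse_args`) might be:

    ```bash
    python train_sdt.py \
      --model_name TinyLlama/TinyLlama-1.1B-Chat-v1.0 \
      --peft qlora --use_4bit \
      --surprise_top_p 0.2 \
      --steps 1000 --val_every 200
    ```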

    ## Notes
    - If your model/config supports FlashAttention V2 in HF, `--flash_attn` will try to enable it (best on recent GPUs).
    - For QLoRA use `--peft qlora --use_4bit`.
    - With `--ema_threshold` the selection adapts to dataset difficulty (target ~70th percentile NLL); the sketch after this list shows the update rule in isolation.
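
    A minimal sketch of the EMA update (same rule as `ema_threshold_mask` in the script; the NLL values are made up):

    ```python
    import torch

    tau, thresh = 0.99, 2.5                     # --ema_tau and the initial threshold
    nll = torch.tensor([0.4, 1.1, 2.9, 5.2])    # hypothetical per-token NLLs of one batch
    target = torch.quantile(nll, 0.7).item()    # ~70th percentile of this batch
    thresh = tau * thresh + (1 - tau) * target  # threshold drifts toward batch difficulty
    keep = nll > thresh                         # only these tokens enter the loss
    print(f"threshold={thresh:.3f}, kept={keep.tolist()}")
    ```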

    ## Why this is different
    Instead of averaging loss over **all** tokens, we optimize only the informative ones. This can:
    - reduce gradient noise,
    - focus compute where it matters,
    - speed up reaching the same validation loss with fewer GPU-hours (to be validated per-task).

    ## Roadmap
    - Add multi-scale (hierarchical) units (sentence/paragraph tokens) for renormalization-like training.
    - Add proper contextual bandits (LinUCB/Thompson) with feature vectors per sample; a minimal LinUCB sketch follows below.
    - Combine with MoE expert pruning + knowledge distillation.
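
    For that contextual-bandit item, a minimal disjoint-LinUCB sketch (the class and the 3-dimensional shard features are illustrative assumptions, not part of `train_sdt.py`):

    ```python
    import numpy as np

    class LinUCB:
        """One ridge-regression model per arm (shard); pick the argmax of the UCB score."""
        def __init__(self, n_arms: int, dim: int, alpha: float = 1.0):
            self.alpha = alpha
            self.A = [np.eye(dim) for _ in range(n_arms)]    # X^T X + I per arm
            self.b = [np.zeros(dim) for _ in range(n_arms)]  # X^T r per arm

        def select(self, x: np.ndarray) -> int:
            scores = []
            for A, b in zip(self.A, self.b):
                A_inv = np.linalg.inv(A)
                theta = A_inv @ b  # ridge estimate of the reward weights
                scores.append(theta @ x + self.alpha * np.sqrt(x @ A_inv @ x))
            return int(np.argmax(scores))

        def update(self, arm: int, x: np.ndarray, reward: float):
            self.A[arm] += np.outer(x, x)
            self.b[arm] += reward * x

    # hypothetical shard features: [mean NLL, surprising-token fraction, mean seq length]
    bandit = LinUCB(n_arms=4, dim=3)
    x = np.array([2.1, 0.3, 0.8])
    arm = bandit.select(x)
    bandit.update(arm, x, reward=-1.9)  # e.g. negative masked loss, as in the UCB variant
    ```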

    train_sdt.py
    #!/usr/bin/env python3
    # -*- coding: utf-8 -*-
    """
    Surprise-Driven Training (SDT) – functional skeleton (HF, 2025)
    Idea: learn primarily from "surprises" (high per-token loss) instead of all tokens.
    This approximates predictive coding: only prediction errors drive updates.

    Features:
    - HF Transformers + Datasets
    - Optional PEFT (LoRA/QLoRA) via PEFT
    - FlashAttention v2 hint (attn_implementation="flash_attention_2") when supported
    - Selective loss: top-p percentile of per-token loss per batch (or EMA threshold)
    - Simple UCB bandit over shards (optional)

    Usage (toy, single GPU):
    pip install -r requirements.txt
    python train_sdt.py --model_name TinyLlama/TinyLlama-1.1B-Chat-v1.0 --peft lora --steps 1000
    """

    import argparse, math, os, random
    from dataclasses import dataclass

    import numpy as np
    import torch
    import torch.nn.functional as F
    from torch.utils.data import DataLoader

    from datasets import load_dataset
    from transformers import (
        AutoTokenizer,
        AutoConfig,
        AutoModelForCausalLM,
        BitsAndBytesConfig,
        DataCollatorForLanguageModeling,
        get_cosine_schedule_with_warmup,
    )
    from accelerate import Accelerator
    from tqdm import tqdm

    try:
        from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
        PEFT_AVAILABLE = True
    except Exception:
        PEFT_AVAILABLE = False

    # --------------------- Config ---------------------

    def parse_args():
        ap = argparse.ArgumentParser()
        ap.add_argument("--model_name", default="TinyLlama/TinyLlama-1.1B-Chat-v1.0")
        ap.add_argument("--dataset", default="wikitext")
        ap.add_argument("--subset", default="wikitext-2-raw-v1")
        ap.add_argument("--seq_len", type=int, default=1024)
        ap.add_argument("--bsz", type=int, default=8)
        ap.add_argument("--steps", type=int, default=2000)
        ap.add_argument("--warmup", type=int, default=100)
        ap.add_argument("--lr", type=float, default=2e-4)
        ap.add_argument("--weight_decay", type=float, default=0.02)

        # Surprise-Driven knobs
        ap.add_argument("--surprise_top_p", type=float, default=0.2, help="Keep top-p most surprising tokens in loss")
        ap.add_argument("--ema_threshold", action="store_true", help="Use EMA threshold instead of top-p")
        ap.add_argument("--ema_tau", type=float, default=0.99)
        ap.add_argument("--min_tokens", type=int, default=64, help="Minimum tokens to keep per batch in loss")

        # Bandit sharding (optional lightweight UCB)
        ap.add_argument("--shards", type=int, default=4)
        ap.add_argument("--use_ucb", action="store_true")

        # PEFT
        ap.add_argument("--peft", choices=["none", "lora", "qlora"], default="none")
        ap.add_argument("--lora_r", type=int, default=16)
        ap.add_argument("--use_4bit", action="store_true")

        # FlashAttention hint
        ap.add_argument("--flash_attn", action="store_true")

        ap.add_argument("--val_every", type=int, default=200)
        ap.add_argument("--save_dir", default="checkpoints_sdt")
        return ap.parse_args()

    # --------------------- Utils ---------------------

    def set_seed(seed: int = 1337):
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    def maybe_enable_flash_attention(enable: bool) -> dict:
        # Return extra kwargs for from_pretrained. Setting attn_implementation on the
        # config object after the fact is not reliably honored; the documented path is
        # the from_pretrained kwarg. Not every model/GPU supports FlashAttention 2,
        # so the caller guards the load with try/except.
        return {"attn_implementation": "flash_attention_2"} if enable else {}

    def attach_peft_if_needed(model, method="none", rank=16, use_4bit=False):
        if method == "none":
            return model
        if not PEFT_AVAILABLE:
            print("[WARN] PEFT not available; ignoring --peft")
            return model
        if method == "qlora" or use_4bit:
            model = prepare_model_for_kbit_training(model)
        if method in ("lora", "qlora"):
            peft_cfg = LoraConfig(
                r=rank, lora_alpha=32, lora_dropout=0.05,
                target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
                bias="none", task_type="CAUSAL_LM",
            )
            model = get_peft_model(model, peft_cfg)
            try:
                model.print_trainable_parameters()
            except Exception:
                pass
        return model

    # --------------------- Data ---------------------

    def make_loaders(tokenizer, dataset_id, subset, seq_len, bsz):
        ds = load_dataset(dataset_id, subset)
        # wikitext contains many empty lines; drop them so no empty sequences reach the collator
        ds = ds.filter(lambda ex: len(ex["text"].strip()) > 0)

        def tok(batch):
            return tokenizer(batch["text"], truncation=True, max_length=seq_len)

        ds = ds.map(tok, batched=True, remove_columns=ds["train"].column_names)
        collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
        train_loader = DataLoader(ds["train"], batch_size=bsz, shuffle=True, collate_fn=collator)
        val_loader = DataLoader(ds["validation"], batch_size=bsz, shuffle=False, collate_fn=collator)
        return train_loader, val_loader

    # --------------------- Surprise mask ---------------------

    @dataclass
    class SurpriseState:
        ema_thresh: float = 2.5  # initial NLL threshold

    def per_token_nll(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        # logits: [B, T, V], labels: [B, T]; compute per-token NLL with the usual causal shift
        B, T, V = logits.shape
        logits = logits[:, :-1, :].contiguous()
        labels = labels[:, 1:].contiguous()
        nll = F.cross_entropy(
            logits.view(-1, V),
            labels.view(-1),
            reduction="none",
            ignore_index=-100,  # padding positions (set by the collator) contribute zero loss
        ).view(B, T - 1)
        return nll

    def top_p_mask(nll: torch.Tensor, top_p: float, min_tokens: int) -> torch.Tensor:
        # Select the top_p fraction of highest-NLL tokens across the batch
        B, Tm1 = nll.shape
        flat = nll.view(-1)
        k = max(min_tokens, int(flat.numel() * max(1e-4, min(top_p, 1.0))))
        k = min(k, flat.numel())  # guard against min_tokens exceeding the batch token count
        topk = torch.topk(flat, k=k, sorted=False).indices
        mask = torch.zeros_like(flat, dtype=torch.bool)
        mask[topk] = True
        return mask.view(B, Tm1)

    def ema_threshold_mask(nll: torch.Tensor, state: SurpriseState, tau: float, min_tokens: int) -> torch.Tensor:
        # Keep tokens with NLL above the EMA threshold; update the EMA to track running difficulty
        with torch.no_grad():
            th = state.ema_thresh
            sel = nll > th
            # Ensure at least min_tokens survive
            if sel.sum().item() < min_tokens:
                # fallback: take the highest-NLL tokens to reach min_tokens
                add_mask = top_p_mask(nll, top_p=min_tokens / nll.numel(), min_tokens=min_tokens)
                sel = sel | add_mask
            # Pull the EMA threshold toward the 70th percentile of this batch
            target = torch.quantile(nll.detach().float().view(-1), 0.7).item()
            state.ema_thresh = tau * state.ema_thresh + (1 - tau) * target
        return sel

    def masked_mean_loss(nll: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        # Average only over selected tokens
        denom = mask.sum()
        if denom == 0:
            return nll.mean()
        return (nll * mask).sum() / denom

    # --------------------- Bandit (UCB over shards) ---------------------

    class UCB:
        def __init__(self, n_arms: int, c: float = 1.2):
            self.n = n_arms
            self.c = c
            self.counts = np.zeros(n_arms, dtype=np.int64)
            self.values = np.zeros(n_arms, dtype=np.float32)

        def select(self, step: int) -> int:
            total = max(1, self.counts.sum())
            ucb = self.values + self.c * np.sqrt(np.log(total) / (self.counts + 1e-9))
            return int(ucb.argmax())

        def update(self, arm: int, reward: float):
            self.counts[arm] += 1
            beta = 0.1  # EMA of reward instead of the classic sample mean
            self.values[arm] = (1 - beta) * self.values[arm] + beta * reward

    # --------------------- Validation ---------------------

    def eval_ppl(model, dataloader, accelerator) -> float:
        model.eval()
        device = accelerator.device
        nll_total, tok_total = 0.0, 0
        with torch.no_grad():
            for batch in dataloader:
                batch = {k: v.to(device) for k, v in batch.items()}
                out = model(**batch)  # collator already supplies "labels" (-100 on padding)
                n_valid = (batch["labels"] != -100).sum().item()
                nll_total += out.loss.item() * n_valid
                tok_total += n_valid
        model.train()
        # sum across processes; every process must reach these gathers
        nll_total = accelerator.gather_for_metrics(torch.tensor([nll_total], device=device)).sum().item()
        tok_total = accelerator.gather_for_metrics(torch.tensor([tok_total], device=device)).sum().item()
        return math.exp(nll_total / max(1, tok_total))

    # --------------------- Main ---------------------

    def main():
        args = parse_args()
        set_seed(1337)
        accelerator = Accelerator(mixed_precision="bf16")
        device = accelerator.device

        # Tokenizer
        tokenizer = AutoTokenizer.from_pretrained(args.model_name, use_fast=True)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        cfg = AutoConfig.from_pretrained(args.model_name)

        # Model (+ optional 4-bit quantization and a FlashAttention 2 attempt)
        quant_kwargs = {}
        if args.use_4bit:
            quant_kwargs = {"quantization_config": BitsAndBytesConfig(
                load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)}
        model_kwargs = dict(config=cfg, torch_dtype=torch.bfloat16, **quant_kwargs)
        model_kwargs.update(maybe_enable_flash_attention(args.flash_attn))
        try:
            model = AutoModelForCausalLM.from_pretrained(args.model_name, **model_kwargs)
            if args.flash_attn:
                print("[INFO] FlashAttention 2 requested via attn_implementation.")
        except Exception as e:
            print(f"[WARN] FlashAttention 2 unavailable ({e}); falling back to default attention.")
            model_kwargs.pop("attn_implementation", None)
            model = AutoModelForCausalLM.from_pretrained(args.model_name, **model_kwargs)

        # PEFT
        model = attach_peft_if_needed(model, method=args.peft, rank=args.lora_r, use_4bit=args.use_4bit)

        # Data
        train_loader, val_loader = make_loaders(tokenizer, args.dataset, args.subset, args.seq_len, args.bsz)

        # Simple sharding for UCB: shard i owns every n_shards-th batch index
        n_shards = max(1, args.shards)
        shards = [list(range(i, len(train_loader), n_shards)) for i in range(n_shards)]
        ucb = UCB(n_shards) if args.use_ucb else None

        optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
        scheduler = get_cosine_schedule_with_warmup(optimizer, args.warmup, args.steps)

        model, optimizer, train_loader, val_loader, scheduler = accelerator.prepare(
            model, optimizer, train_loader, val_loader, scheduler
        )

        surprise_state = SurpriseState()
        model.train()
        step = 0
        pbar = tqdm(total=args.steps, disable=not accelerator.is_local_main_process)

        # To index the DataLoader deterministically per "step", we materialize it once.
        # (Fine for toy datasets; for large corpora, iterate per-shard loaders instead.)
        train_batches = list(train_loader)

        while step < args.steps:
            # Bandit shard selection
            if ucb is not None:
                shard_idx = ucb.select(step)
                shard_indices = shards[shard_idx]
                batch = train_batches[shard_indices[step % len(shard_indices)]]
            else:
                batch = train_batches[step % len(train_batches)]

            batch = {k: v.to(device) for k, v in batch.items()}  # copy, so the cached batch stays intact
            labels = batch.pop("labels")  # collator sets -100 on padding; pop to avoid a duplicate kwarg

            # Forward – we compute our own per-token loss
            outputs = model(**batch)
            nll = per_token_nll(outputs.logits, labels)  # [B, T-1]

            # Surprise selection
            if args.ema_threshold:
                mask = ema_threshold_mask(nll, surprise_state, tau=args.ema_tau, min_tokens=args.min_tokens)
            else:
                mask = top_p_mask(nll, top_p=args.surprise_top_p, min_tokens=args.min_tokens)

            loss = masked_mean_loss(nll, mask)

            accelerator.backward(loss)
            optimizer.step(); scheduler.step(); optimizer.zero_grad()

            # Bandit reward: negative masked loss (lower loss = higher reward)
            if ucb is not None:
                rew = float(-accelerator.gather_for_metrics(loss.detach()).mean().item())
                ucb.update(shard_idx, rew)

            if step % args.val_every == 0 and step > 0:
                # eval_ppl gathers across processes, so every process must call it
                ppl = eval_ppl(model, val_loader, accelerator)
                if accelerator.is_local_main_process:
                    print(f"[step {step}] ppl={ppl:.2f}")
                    os.makedirs(args.save_dir, exist_ok=True)
                    save_path = os.path.join(args.save_dir, f"ckpt_{step}.pt")
                    torch.save({"model": accelerator.unwrap_model(model).state_dict(), "step": step}, save_path)

            step += 1
            pbar.update(1)

        pbar.close()
        if accelerator.is_local_main_process:
            os.makedirs(args.save_dir, exist_ok=True)
            torch.save({"model": accelerator.unwrap_model(model).state_dict(), "step": step},
                       os.path.join(args.save_dir, f"ckpt_{step}.pt"))
        print("Done.")

    if __name__ == "__main__":
        main()
