So, is "Tush-aa-r" (with two "a"s) more surprising than "Tush-a-r" (with one "a")?

In character-level (which is byte-level for English) language modeling, the model's surprise is measured in bits-per-byte (BPB): the lower the BPB, the less surprised the model is upon seeing the text (think: "hant" follows "elep"). Out of scientific curiosity, I used the EvaByte-6.5B model (trained on 1.5B tokens) to compute the BPB of "ushaar" (two "a"s) vs. "ushar" (single "a"), both starting from "T".

    
# !pip install bitsandbytes

import torch as t
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM

import matplotlib.pyplot as plt
import seaborn as sns


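# Load the EvaByte-6.5B byte-level tokenizer and model (trained on 1.5B tokens)
# from HuggingFace. Note: no quantization config is passed below, so the model is
# not actually loaded in 8-bit; that would need bitsandbytes plus a quantization config.
# See: https://huggingface.co/EvaByte/EvaByte.
model_name = "evabyte/EvaByte"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)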
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    return_dict_in_generate=True,
    output_attentions=True,
    trust_remote_code=True,
)
model.eval()

# Encode "Tushaar" and "Tushar" as bytes (= characters) and process them.
for name in ["Tushaar", "Tushar"]:
    # Tokenization + extract input IDs and position IDs.
    inputs = tokenizer(name, return_tensors="pt")
    iids = inputs["input_ids"].cuda()
    pids = t.arange(0, iids.shape[1]).expand_as(iids).cuda()

    with t.no_grad():
        # outputs.logits: (b=1, L, V=320)
        outputs = model(
            iids, 
            position_ids=pids, 
            return_dict=True, 
            output_attentions=False,  # Needed for EvaByte
            output_hidden_states=True,
        )

    log_probs = F.log_softmax(outputs.logits, dim=-1)
    # Ignore the `<bos>` token and start as if we were generating from "T".
    nll = F.nll_loss(
        input=log_probs[0, 1:-1, :].reshape(-1, outputs.logits.shape[-1]), 
        target=iids[0, 2:].reshape(-1),
    )
    bpb = (nll / t.log(t.tensor(2.0))).item()  # bits per byte
    print(f"{name}: {bpb:.2f} bits/byte")

# >>> Tushaar: 4.40 bits/byte
# >>> Tushar: 3.58 bits/byte
    
    

Okay, as it turns out, starting from "T", "ushaar" (two "a"s) comes in at 4.40 bits/byte, while "ushar" (single "a") comes in at 3.58 bits/byte. So, at least according to EvaByte, "Tush-aa-r" really is more surprising than "Tush-a-r"!! Double the "a"s, double the "!!?"!

The "surprise," well, is unsurprising!
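For a sense of scale, the arithmetic behind these numbers can be sketched in a few lines. This is a minimal, dependency-free sketch, not the code used above: `bits_per_byte` is a hypothetical helper, and the probabilities in the toy example are made up for illustration. The only real inputs are the two BPB values reported above and the byte counts (6 for "ushaar", 5 for "ushar").

```python
import math


def bits_per_byte(probs):
    """Mean surprisal in bits, given the probability assigned to each observed byte."""
    return sum(-math.log2(p) for p in probs) / len(probs)


# Toy example with made-up probabilities (NOT EvaByte outputs):
# a byte predicted with p = 0.5 costs 1 bit, one with p = 0.25 costs 2 bits.
toy_bpb = bits_per_byte([0.5, 0.25])

# Reported means from the run above: (bits/byte, number of bytes scored).
reported = {"ushaar": (4.40, 6), "ushar": (3.58, 5)}

# BPB is a mean, so multiplying by the byte count recovers the total surprisal.
total_bits = {s: bpb * n for s, (bpb, n) in reported.items()}
# Per-byte perplexity: the effective number of equally likely next bytes.
perplexity = {s: 2 ** bpb for s, (bpb, _) in reported.items()}

for s in reported:
    print(f"{s}: {total_bits[s]:.1f} bits total, per-byte perplexity {perplexity[s]:.1f}")
```

Since BPB is an average, the second "a" both raises the per-byte cost and adds a byte, so the gap in total bits (about 26.4 vs. 17.9) is wider than the gap in BPB alone.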


For fun, we can plot the causal (= left-to-right, where each character sees only itself and the characters before it), normalized self-similarity softmax matrix using the hidden states at each character in "Tushaar". You can think of this as capturing the essence of the attention matrix in purely-attention-based models (like GPT-2).

    
from jaxtyping import Float

Fl = lambda size: Float[t.Tensor, size]


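def sim_h(h: Fl("b L d") | Fl("L d")):
    """Causal softmax over cosine self-similarities; returns an (L, L) NumPy array."""
    if h.ndim == 3 and h.shape[0] == 1:
        h = h.squeeze(0)
    # Cast to float32 (NumPy has no bfloat16) and normalize out of place,
    # so the model's hidden states are not modified.
    h = h.float()
    h = h / h.norm(dim=-1, keepdim=True)
    sim = h @ h.T  # (L, L) cosine similarities
    # Mask future positions *before* the softmax, so each row is a softmax over
    # the current and previous positions only. (Note that `h@h.T.tril()` applies
    # `.tril()` to `h.T` alone, since method calls bind tighter than `@`.)
    future = t.triu(t.ones_like(sim), diagonal=1).bool()
    return F.softmax(sim.masked_fill(future, float("-inf")), dim=-1).cpu().numpy()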

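# Plot the self-similarity matrix of the last layer for "Tushaar".
# After the loop above, `outputs` and `iids` belong to the last name processed
# ("Tushar"), so re-run the model on "Tushaar" first.
iids = tokenizer("Tushaar", return_tensors="pt")["input_ids"].cuda()
pids = t.arange(0, iids.shape[1]).expand_as(iids).cuda()
with t.no_grad():
    outputs = model(
        iids,
        position_ids=pids,
        return_dict=True,
        output_attentions=False,  # Needed for EvaByte
        output_hidden_states=True,
    )
sim_mat = sim_h(outputs.hidden_states[-1])
sim_mat = sim_mat[1:, 1:]  # ignore `<bos>`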

fig, ax = plt.subplots(figsize=(6, 6))
toks = [tokenizer.decode(id) 
        for id in iids[0, 1:]]
sns.heatmap(
    sim_mat,
    cmap="rocket",
    cbar=True,
    xticklabels=toks, 
    yticklabels=toks, 
    annot=True,
)
plt.yticks(rotation=0) 
plt.show()
    
    
Causal self-similarity matrix of "Tushaar" (with two "a"s).
[Output from running the code above.]

... which, coincidentally, happens to be the favicon of this website!

It is interesting to spot the clear(-ish) divide between "Tu" and "shaar" in the self-similarity heatmap, and funnily enough, this matches how my name is actually pronounced: "Thuh-shaar"!! Wow, it's almost like language and speech are connected or something (*shocking*, I know!).
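To see how such a divide can show up in a causal self-similarity matrix at all, here is a tiny sketch on toy data. It is an illustration under stated assumptions, not the EvaByte computation: `causal_self_similarity` is a hypothetical pure-Python helper, and the 2-D "hidden states" are made up so that the first two positions point one way and the last three point another.

```python
import math


def causal_self_similarity(vectors):
    """Row-wise softmax over cosine similarities, restricted to positions j <= i."""
    # Normalize each vector to unit length so dot products are cosine similarities.
    unit = []
    for v in vectors:
        norm = math.sqrt(sum(x * x for x in v))
        unit.append([x / norm for x in v])

    n = len(unit)
    sim = [[0.0] * n for _ in range(n)]
    for i in range(n):
        # Only the current and earlier positions take part in the softmax.
        scores = [sum(a * b for a, b in zip(unit[i], unit[j])) for j in range(i + 1)]
        exps = [math.exp(s) for s in scores]
        z = sum(exps)
        for j in range(i + 1):
            sim[i][j] = exps[j] / z
    return sim


# Made-up vectors (NOT model outputs): two clusters of directions.
toy = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9], [0.0, 1.0]]
sim = causal_self_similarity(toy)
for row in sim:
    print(" ".join(f"{x:.2f}" for x in row))
```

In the printed matrix, the later rows put more of their weight on columns from their own cluster than on the first two columns, which is the same kind of block structure as the "Tu" / "shaar" split, just on invented data.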

