So, is "Tush-aa-r" (with two "a"s) is more surprising than "Tush-a-r" (with one "a")?
In character-level (= byte-level for English) language modeling, the model's surprise is measured in bits-per-byte (BPB): the lower the BPB, the less surprised the model is upon seeing the text (think: "hant" follows "elep"). Out of scientific curiosity, I used the EvaByte-6.5B model (trained on 1.5T tokens) to compute the BPB of "ushaar" (two "a"s) vs. "ushar" (single "a"), both starting from "T".
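Concretely, BPB is just the model's average negative log-probability of each byte (given everything before it), measured in bits. A tiny sketch with made-up next-byte probabilities, purely to pin down the formula (the real numbers come from the model below):
import math
# Hypothetical next-byte probabilities p(byte_i | preceding bytes); the values
# are invented just to illustrate the formula.
probs = [0.25, 0.5, 0.125]
bpb = -sum(math.log2(p) for p in probs) / len(probs)
print(bpb)  # (2 + 1 + 3) / 3 = 2.0 bits/byte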
# !pip install bitsandbytes
import torch as t
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
import matplotlib.pyplot as plt
import seaborn as sns
# Load the EvaByte-6.5B byte-level tokenizer and model (trained on 1.5T tokens)
# in 8-bit from HuggingFace.
# See: https://huggingface.co/EvaByte/EvaByte.
model_name = "evabyte/EvaByte"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    load_in_8bit=True,  # 8-bit weights via bitsandbytes
    return_dict_in_generate=True,
    output_attentions=True,
    trust_remote_code=True,
)
model.eval()
# Encode "Tushaar" and "Tushar" as bytes (= characters) and process them.
for name in ["Tushaar", "Tushar"]:
# Tokenization + extract input IDs and position IDs.
inputs = tokenizer(name, return_tensors="pt")
iids = inputs["input_ids"].cuda()
pids = t.arange(0, iids.shape[1]).expand_as(iids).cuda()
with t.no_grad():
# outputs.logits: (b=1, L, V=320)
outputs = model(
iids,
position_ids=pids,
return_dict=True,
output_attentions=False, # Needed for EvaByte
output_hidden_states=True,
)
log_probs = F.log_softmax(outputs.logits, dim=-1)
# Ignore the `<bos>` token and start as if we were generating from "T".
nll = F.nll_loss(
input=log_probs[0, 1:-1, :].reshape(-1, outputs.logits.shape[-1]),
target=iids[0, 2:].reshape(-1),
)
bpb = (nll / t.log(t.tensor(2.0))).item() # bits per byte
print(f"{name}: {bpb:.2f} bits/byte")
# >>> Tushaar: 4.40 bits/byte
# >>> Tushar: 3.58 bits/byte
Okay, as it turns out, starting from "T", the BPB of "ushaar" (two "a"s) is 4.40 bits/byte, while the BPB of "ushar" (single "a") is 3.58 bits/byte. Conclusively, [EvaByte thinks that] "Tush-aa-r" really is more surprising than "Tush-a-r"!!—double the "a"s, double the "!!?"!
The "surprise," well, is unsurprising!
For fun, we can plot the causal (= left-to-right, only accessing the previous tokens), normalized self-similarity softmax matrix using the hidden states at each character in "Tushaar". You can think of this as capturing the essence of the attention matrix in purely attention-based models (like GPT-2).
from jaxtyping import Float

Fl = lambda size: Float[t.Tensor, size]

def sim_h(h: Fl("b L d") | Fl("L d")) -> Fl("L L"):
    if h.ndim == 3 and h.shape[0] == 1:
        h = h.squeeze(0)  # drop the batch dim: (1, L, d) -> (L, d)
    h = h / h.norm(dim=-1, keepdim=True)  # unit-normalize each hidden state
    sim = h @ h.T  # cosine similarities: (L, L)
    # Causal mask: position i only looks at positions j <= i (like attention).
    mask = t.triu(t.ones_like(sim, dtype=t.bool), diagonal=1)
    return F.softmax(sim.masked_fill(mask, float("-inf")), dim=-1).float().cpu().numpy()
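# Quick sanity check of sim_h on a toy tensor (random numbers and made-up
# shapes, not model outputs): each row i of the result is a distribution over
# positions <= i, so row 0 should be [1, 0, 0, 0, 0].
h_toy = t.randn(1, 5, 8)   # (b=1, L=5, d=8)
print(sim_h(h_toy).shape)  # (5, 5)
print(sim_h(h_toy)[0])     # [1. 0. 0. 0. 0.]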
# Plot the last layer's self-similarity matrix for "Tushaar". (The loop above
# exits with outputs/iids holding the *last* name, "Tushar", so re-run the
# forward pass on "Tushaar" first.)
iids = tokenizer("Tushaar", return_tensors="pt")["input_ids"].cuda()
pids = t.arange(0, iids.shape[1]).expand_as(iids).cuda()
with t.no_grad():
    outputs = model(iids, position_ids=pids, return_dict=True,
                    output_attentions=False, output_hidden_states=True)
sim_mat = sim_h(outputs.hidden_states[-1])
sim_mat = sim_mat[1:, 1:]  # ignore `<bos>`
fig, ax = plt.subplots(figsize=(6, 6))
toks = [tokenizer.decode(id) for id in iids[0, 1:]]
sns.heatmap(
    sim_mat,
    cmap="rocket",
    cbar=True,
    xticklabels=toks,
    yticklabels=toks,
    annot=True,
)
plt.yticks(rotation=0)
plt.show()
... which, coincidentally, happens to be the favicon of this website!
It is interesting to spot the clear(-ish) divide between "Tu" and "shaar" in the self-similarity heatmap—and funny enough, this matches how my name is actually pronounced: "Thuh-shaar"!! Wow, it's almost like language and speech are connected or something (*shocking*, I know!).