---
title: "Using a GPT2 transformer model to get word predictability"
bibliography: '`r system.file("REFERENCES.bib", package="pangoling")`'
vignette: >
  %\VignetteIndexEntry{Using a GPT2 transformer model to get word predictability}
  %\VignetteEngine{knitr::rmarkdown}
  %\VignetteEncoding{UTF-8}
---
Transformer models are a type of neural network architecture used for natural
language processing tasks such as language translation and text generation. They
were introduced in the @vaswani2017attention paper "Attention Is All You Need".
Large Language Models (LLMs) are a specific type of pre-trained transformer
model. These models have been trained on massive amounts of text data and can
be fine-tuned to perform a variety of NLP tasks, such as text classification,
named entity recognition, and question answering.
A causal language model (also called a GPT-like, auto-regressive, or decoder
model) is a type of large language model usually used for text generation; it
predicts the next word (or, more accurately, the next token) based on the
preceding context. GPT-2 (Generative Pre-trained Transformer 2), developed by
OpenAI, is an example of a causal language model [see also @radford2019language].
One interesting side effect of causal language models is that the (log)
probability of a word given a certain context can be extracted from them.
Load the following packages first:
``` r
library(pangoling)
library(tidytable) # fast alternative to dplyr
library(tictoc) # measure time
```
Then let's examine which continuation GPT-2 predicts following a specific
context. [Hugging Face](https://huggingface.co/) provides access to pre-trained
models, including freely available versions of
[GPT-2](https://huggingface.co/gpt2) in different sizes. The function
`causal_next_tokens_pred_tbl()` will, by default, use the smallest version of
[GPT-2](https://huggingface.co/gpt2), but this can be changed with the argument
`model`.
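For instance, a larger GPT-2 checkpoint can be requested by name. The sketch below is not run here; `"gpt2-medium"` is just one of the larger GPT-2 versions hosted on Hugging Face, and downloading it takes considerably more time and disk space:
``` r
# Not run: the same type of call as below, but with a larger GPT-2 checkpoint.
causal_next_tokens_pred_tbl(
  "The apple doesn't fall far from the",
  model = "gpt2-medium"
)
```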
Let's see what GPT-2 predicts following "The apple doesn't fall far from the".
``` r
tic()
(df_pred <- causal_next_tokens_pred_tbl("The apple doesn't fall far from the"))
#> Processing using causal model 'gpt2/' ...
#> # A tidytable: 50,257 × 2
#> token pred
#> <chr> <dbl>
#> 1 Ġtree -0.281
#> 2 Ġtrees -3.60
#> 3 Ġapple -4.29
#> 4 Ġtable -4.50
#> 5 Ġhead -4.83
#> 6 Ġmark -4.86
#> 7 Ġcake -4.91
#> 8 Ġground -5.08
#> 9 Ġtruth -5.31
#> 10 Ġtop -5.36
#> # ℹ 50,247 more rows
toc()
#> 5.438 sec elapsed
```
(The pretrained models and tokenizers will be downloaded from
https://huggingface.co/ the first time they are used.)
The most likely continuation is "tree", which makes sense.
The first time a model is run, it will download some files that remain
available for subsequent runs. However, every time we start a new R session and
run a model, it takes some time to load it into memory; subsequent runs in the
same session are much faster. We can also preload a model with
`causal_preload()`.
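For example, the default GPT-2 model could be preloaded as follows (not run here):
``` r
# Load the default GPT-2 model into memory ahead of time, so that
# later calls in this session don't pay the start-up cost again.
causal_preload("gpt2")
```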
``` r
tic()
(df_pred <- causal_next_tokens_pred_tbl("The apple doesn't fall far from the"))
#> Processing using causal model 'gpt2/' ...
#> # A tidytable: 50,257 × 2
#> token pred
#> <chr> <dbl>
#> 1 Ġtree -0.281
#> 2 Ġtrees -3.60
#> 3 Ġapple -4.29
#> 4 Ġtable -4.50
#> 5 Ġhead -4.83
#> 6 Ġmark -4.86
#> 7 Ġcake -4.91
#> 8 Ġground -5.08
#> 9 Ġtruth -5.31
#> 10 Ġtop -5.36
#> # ℹ 50,247 more rows
toc()
#> 0.773 sec elapsed
```
Notice that the predicted tokens (that is, the way GPT-2 interprets words)
start with `Ġ`; this indicates that they are not the first word of a
sentence.
In fact, this is how GPT-2 interprets our context:
``` r
tokenize_lst("The apple doesn't fall far from the")
#> [[1]]
#> [1] "The" "Ġapple" "Ġdoesn" "'t" "Ġfall" "Ġfar" "Ġfrom" "Ġthe"
```
Also notice that the GPT-2 tokenizer treats tokens at the beginning of a text
differently from tokens that follow a space: a space preceding a token is
indicated with `Ġ`.
``` r
tokenize_lst("This is different from This")
#> [[1]]
#> [1] "This" "Ġis" "Ġdifferent" "Ġfrom" "ĠThis"
```
It's also possible to decode the tokens to get "pure" text:
``` r
tokenize_lst("This is different from This", decode = TRUE)
#> [[1]]
#> [1] "This" " is" " different" " from" " This"
```
Going back to the initial example: because `causal_next_tokens_pred_tbl()`
returns natural log-probabilities by default, exponentiating and summing them
should give 1:
``` r
sum(exp(df_pred$pred))
#> [1] 1.000017
```
Because of approximation errors, this is not exactly one.
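Relatedly, individual log-probabilities can be converted back into probabilities with `exp()`. For instance, the most likely continuation shown above has a log-probability of about -0.281, which corresponds to a probability of roughly 0.75 (a quick check, not run here):
``` r
# Probability of the top-ranked continuation (" tree"); with the
# log-probability of about -0.281 shown above, this is roughly 0.75.
exp(df_pred$pred[1])
```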
For testing,
[`sshleifer/tiny-gpt2`](https://huggingface.co/sshleifer/tiny-gpt2) is quite
useful: because it is a tiny model, it is much faster. But notice that its
predictions are quite bad.
``` r
causal_preload("sshleifer/tiny-gpt2")
#> Preloading causal model sshleifer/tiny-gpt2...
tic()
causal_next_tokens_pred_tbl("The apple doesn't fall far from the",
model = "sshleifer/tiny-gpt2"
)
#> Processing using causal model 'sshleifer/tiny-gpt2/' ...
#> # A tidytable: 50,257 × 2
#> token pred
#> <chr> <dbl>
#> 1 Ġstairs -10.7
#> 2 Ġvendors -10.7
#> 3 Ġintermittent -10.7
#> 4 Ġhauled -10.7
#> 5 ĠBrew -10.7
#> 6 Rocket -10.7
#> 7 dit -10.7
#> 8 ĠHabit -10.7
#> 9 ĠJr -10.7
#> 10 ĠRh -10.7
#> # ℹ 50,247 more rows
toc()
#> 0.095 sec elapsed
```
All in all, the package `pangoling` is most useful in situations like the
following (see also the [worked-out example vignette](example.html)).
Given a (toy) dataset where sentences are organized with one word or short
phrase in each row:
``` r
sentences <- c(
"The apple doesn't fall far from the tree.",
"Don't judge a book by its cover."
)
df_sent <- strsplit(x = sentences, split = " ") |>
map_dfr(.f = ~ data.frame(word = .x), .id = "sent_n")
df_sent
#> # A tidytable: 15 × 2
#> sent_n word
#>
#> 1 1 The
#> 2 1 apple
#> 3 1 doesn't
#> 4 1 fall
#> 5 1 far
#> 6 1 from
#> 7 1 the
#> 8 1 tree.
#> 9 2 Don't
#> 10 2 judge
#> 11 2 a
#> 12 2 book
#> 13 2 by
#> 14 2 its
#> 15 2 cover.
```
One can get the natural log-transformed probability of each word based on
GPT-2 as follows:
``` r
df_sent <- df_sent |>
mutate(lp = causal_words_pred(word, by = sent_n))
#> Processing using causal model 'gpt2/' ...
#> Processing a batch of size 1 with 10 tokens.
#> Processing a batch of size 1 with 9 tokens.
#> Text id: 1
#> `The apple doesn't fall far from the tree.`
#> Text id: 2
#> `Don't judge a book by its cover.`
#> ***
df_sent
#> # A tidytable: 15 × 3
#> sent_n word lp
#>
#> 1 1 The NA
#> 2 1 apple -10.9
#> 3 1 doesn't -5.50
#> 4 1 fall -3.60
#> 5 1 far -2.91
#> 6 1 from -0.745
#> 7 1 the -0.207
#> 8 1 tree. -1.58
#> 9 2 Don't NA
#> 10 2 judge -6.27
#> 11 2 a -2.33
#> 12 2 book -1.97
#> 13 2 by -0.409
#> 14 2 its -0.257
#> 15 2 cover. -1.38
```
Notice that `by` is an argument of `causal_words_pred()` here. It's also
possible to use `by` in the `mutate()` call, or to use `group_by()`, but it
will be slower.
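As a sketch, the slower `group_by()` alternative mentioned above would look roughly like this (not run; it relies on tidytable's grouping so that `causal_words_pred()` is applied to each sentence separately):
``` r
# Slower alternative: let tidytable handle the grouping instead of
# passing `by` to causal_words_pred().
df_sent |>
  group_by(sent_n) |>
  mutate(lp = causal_words_pred(word)) |>
  ungroup()
```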
The attentive reader might have noticed that the log-probability of "tree" here
is not the same as the one presented before. This is because the actual word is
`" tree."` (notice the space), which contains two tokens:
``` r
tokenize_lst(" tree.")
#> [[1]]
#> [1] "Ġtree" "."
```
The log-probability of `" tree."` is the sum of the log-probabilities of
`" tree"` and of `"."`, each given its preceding context.
We can verify this in the following way.
``` r
df_token_lp <- causal_tokens_pred_lst(
"The apple doesn't fall far from the tree.") |>
# convert the list into a data frame
map_dfr(~ data.frame(token = names(.x), pred = .x))
#> Processing using causal model 'gpt2/' ...
#> Processing a batch of size 1 with 10 tokens.
df_token_lp
#> # A tidytable: 10 × 2
#> token pred
#> <chr> <dbl>
#> 1 The NA
#> 2 Ġapple -10.9
#> 3 Ġdoesn -5.50
#> 4 't -0.000828
#> 5 Ġfall -3.60
#> 6 Ġfar -2.91
#> 7 Ġfrom -0.745
#> 8 Ġthe -0.207
#> 9 Ġtree -0.281
#> 10 . -1.30
(tree_lp <- df_token_lp |>
# requires a Ġ because there is a space before
filter(token == "Ġtree") |>
pull())
#> [1] -0.2808024
(dot_lp <- df_token_lp |>
# doesn't require a Ġ because there is no space before
filter(token == ".") |>
pull())
#> [1] -1.300929
tree._lp <- df_sent |>
filter(word == "tree.") |>
pull()
# Test whether it is equal
all.equal(
tree_lp + dot_lp,
tree._lp
)
#> [1] TRUE
```
In a scenario like the one above, where one has a word-by-word text and wants
to know the log-probability of each word, one doesn't need to worry about the
encoding or the tokens, since the function `causal_words_pred()` takes care of
all that.
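Once the word-level log-probabilities are in the data frame, they can be summarized with the usual tidytable verbs. For instance, a sketch of the mean log-probability per sentence (assuming the `.by` argument of `summarize()` is available in the installed tidytable version):
``` r
# Mean log-probability per sentence; the first word of each sentence
# has no preceding context, so its NA is ignored.
df_sent |>
  summarize(mean_lp = mean(lp, na.rm = TRUE), .by = sent_n)
```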
# References