View on GitHub
GitHub
Neural networks
Attention in transformers, step-by-step | Deep Learning Chapter 6
Loading player
Notes
Transcript
423 segments
0:00
In the last chapter, you and I started to step
0:01
through the internal workings of a transformer.
0:04
This is one of the key pieces of technology inside large language models,
0:07
and a lot of other tools in the modern wave of AI.
0:10
It first hit the scene in a now-famous 2017 paper called Attention is All You Need,
0:15
and in this chapter you and I will dig into what this attention mechanism is,
0:19
visualizing how it processes data.
0:26
As a quick recap, here's the important context I want you to have in mind.
0:30
The goal of the model that you and I are studying is to
0:32
take in a piece of text and predict what word comes next.
0:36
The input text is broken up into little pieces that we call tokens,
0:40
and these are very often words or pieces of words,
0:42
but just to make the examples in this video easier for you and me to think about,
0:47
let's simplify by pretending that tokens are always just words.
0:51
The first step in a transformer is to associate each token
0:54
with a high-dimensional vector, what we call its embedding.
0:57
The most important idea I want you to have in mind is how directions in this
1:02
high-dimensional space of all possible embeddings can correspond with semantic meaning.
1:07
In the last chapter we saw an example for how direction can correspond to gender,
1:11
in the sense that adding a certain step in this space can take you from the
1:15
embedding of a masculine noun to the embedding of the corresponding feminine noun.
1:20
That's just one example you could imagine how many other directions in this
1:23
high-dimensional space could correspond to numerous other aspects of a word's meaning.
1:28
The aim of a transformer is to progressively adjust these embeddings
1:32
so that they don't merely encode an individual word,
1:35
but instead they bake in some much, much richer contextual meaning.
1:40
I should say up front that a lot of people find the attention mechanism,
1:43
this key piece in a transformer, very confusing,
1:46
so don't worry if it takes some time for things to sink in.
1:49
I think that before we dive into the computational details and all
1:52
the matrix multiplications, it's worth thinking about a couple
1:55
examples for the kind of behavior that we want attention to enable.
2:00
Consider the phrases American shrew mole, one mole of carbon dioxide,
2:04
and take a biopsy of the mole.
2:06
You and I know that the word mole has different meanings in each one of these,
2:09
based on the context.
2:11
But after the first step of a transformer, the one that breaks up the text
2:15
and associates each token with a vector, the vector that's associated with
2:18
mole would be the same in all of these cases,
2:21
because this initial token embedding is effectively a lookup table with no
2:24
reference to the context.
2:26
It's only in the next step of the transformer that the surrounding
2:29
embeddings have the chance to pass information into this one.
2:33
The picture you might have in mind is that there are multiple distinct directions in
2:38
this embedding space encoding the multiple distinct meanings of the word mole,
2:42
and that a well-trained attention block calculates what you need to add to the generic
2:47
embedding to move it to one of these specific directions, as a function of the context.
2:53
To take another example, consider the embedding of the word tower.
2:57
This is presumably some very generic, non-specific direction in the space,
3:01
associated with lots of other large, tall nouns.
3:04
If this word was immediately preceded by Eiffel,
3:06
you could imagine wanting the mechanism to update this vector so that
3:10
it points in a direction that more specifically encodes the Eiffel tower,
3:14
maybe correlated with vectors associated with Paris and France and things made of steel.
3:19
If it was also preceded by the word miniature,
3:22
then the vector should be updated even further,
3:24
so that it no longer correlates with large, tall things.
3:29
More generally than just refining the meaning of a word,
3:32
the attention block allows the model to move information encoded in
3:35
one embedding to that of another, potentially ones that are quite far away,
3:39
and potentially with information that's much richer than just a single word.
3:43
What we saw in the last chapter was how after all of the vectors flow through the
3:47
network, including many different attention blocks,
3:50
the computation you perform to produce a prediction of the next token is entirely a
3:55
function of the last vector in the sequence.
3:59
Imagine, for example, that the text you input is most of an entire mystery novel,
4:03
all the way up to a point near the end, which reads, therefore the murderer was.
4:08
If the model is going to accurately predict the next word,
4:11
that final vector in the sequence, which began its life simply embedding the word was,
4:16
will have to have been updated by all of the attention blocks to represent much,
4:20
much more than any individual word, somehow encoding all of the information
4:24
from the full context window that's relevant to predicting the next word.
4:29
To step through the computations, though, let's take a much simpler example.
4:32
Imagine that the input includes the phrase, a
4:35
fluffy blue creature roamed the verdant forest.
4:38
And for the moment, suppose that the only type of update that we care about
4:42
is having the adjectives adjust the meanings of their corresponding nouns.
4:47
What I'm about to describe is what we would call a single head of attention,
4:50
and later we will see how the attention block consists of many different heads run in
4:54
parallel.
4:56
Again, the initial embedding for each word is some high dimensional vector
4:59
that only encodes the meaning of that particular word with no context.
5:04
Actually, that's not quite true.
5:05
They also encode the position of the word.
5:07
There's a lot more to say about the specific way that positions are encoded,
5:11
but right now, all you need to know is that the entries of this vector are
5:15
enough to tell you both what the word is and where it exists in the context.
5:19
Let's go ahead and denote these embeddings with the letter e.
5:22
The goal is to have a series of computations produce a new refined
5:26
set of embeddings where, for example, those corresponding to the
5:29
nouns have ingested the meaning from their corresponding adjectives.
5:33
And playing the deep learning game, we want most of the computations
5:37
involved to look like matrix-vector products,
5:39
where the matrices are full of tuneable weights,
5:41
things that the model will learn based on data.
5:44
To be clear, I'm making up this example of adjectives updating nouns just to
5:48
illustrate the type of behavior that you could imagine an attention head doing.
5:52
As with so much deep learning, the true behavior is much harder to parse because it's
5:57
based on tweaking and tuning a huge number of parameters to minimize some cost function.
6:01
It's just that as we step through all of different matrices filled with parameters
6:05
that are involved in this process, I think it's really helpful to have an imagined
6:09
example of something that it could be doing to help keep it all more concrete.
6:14
For the first step of this process, you might imagine each noun, like creature,
6:18
asking the question, hey, are there any adjectives sitting in front of me?
6:22
And for the words fluffy and blue, to each be able to answer,
6:25
yeah, I'm an adjective and I'm in that position.
6:28
That question is somehow encoded as yet another vector,
6:32
another list of numbers, which we call the query for this word.
6:36
This query vector though has a much smaller dimension than the embedding vector, say 128.
6:42
Computing this query looks like taking a certain matrix,
6:46
which I'll label wq, and multiplying it by the embedding.
6:50
Compressing things a bit, let's write that query vector as q,
6:54
and then anytime you see me put a matrix next to an arrow like this one,
6:58
it's meant to represent that multiplying this matrix by the vector at the arrow's start
7:02
gives you the vector at the arrow's end.
7:05
In this case, you multiply this matrix by all of the embeddings in the context,
7:10
producing one query vector for each token.
7:13
The entries of this matrix are parameters of the model,
7:16
which means the true behavior is learned from data, and in practice,
7:19
what this matrix does in a particular attention head is challenging to parse.
7:23
But for our sake, imagining an example that we might hope that it would learn,
7:27
we'll suppose that this query matrix maps the embeddings of nouns to
7:31
certain directions in this smaller query space that somehow encodes
7:34
the notion of looking for adjectives in preceding positions.
7:38
As to what it does to other embeddings, who knows?
7:41
Maybe it simultaneously tries to accomplish some other goal with those.
7:44
Right now, we're laser focused on the nouns.
7:47
At the same time, associated with this is a second matrix called the key matrix,
7:51
which you also multiply by every one of the embeddings.
7:55
This produces a second sequence of vectors that we call the keys.
7:59
Conceptually, you want to think of the keys as potentially answering the queries.
8:03
This key matrix is also full of tuneable parameters, and just like the query matrix,
8:07
it maps the embedding vectors to that same smaller dimensional space.
8:12
You think of the keys as matching the queries whenever they closely align with each other.
8:17
In our example, you would imagine that the key matrix maps the adjectives like fluffy and
8:22
blue to vectors that are closely aligned with the query produced by the word creature.
8:27
To measure how well each key matches each query,
8:30
you compute a dot product between each possible key-query pair.
8:34
I like to visualize a grid full of a bunch of dots,
8:37
where the bigger dots correspond to the larger dot products,
8:40
the places where the keys and queries align.
8:43
For our adjective noun example, that would look a little more like this,
8:47
where if the keys produced by fluffy and blue really do align closely with the query
8:52
produced by creature, then the dot products in these two spots would be some large
8:57
positive numbers.
8:59
In the lingo, machine learning people would say that this means the
9:02
embeddings of fluffy and blue attend to the embedding of creature.
9:06
By contrast to the dot product between the key for some other
9:09
word like the and the query for creature would be some small
9:12
or negative value that reflects that are unrelated to each other.
9:17
So we have this grid of values that can be any real number from
9:21
negative infinity to infinity, giving us a score for how relevant
9:25
each word is to updating the meaning of every other word.
9:29
The way we're about to use these scores is to take a certain
9:32
weighted sum along each column, weighted by the relevance.
9:36
So instead of having values range from negative infinity to infinity,
9:40
what we want is for the numbers in these columns to be between 0 and 1,
9:43
and for each column to add up to 1, as if they were a probability distribution.
9:49
If you're coming in from the last chapter, you know what we need to do then.
9:52
We compute a softmax along each one of these columns to normalize the values.
10:00
In our picture, after you apply softmax to all of the columns,
10:03
we'll fill in the grid with these normalized values.
10:06
At this point you're safe to think about each column as giving weights according
10:10
to how relevant the word on the left is to the corresponding value at the top.
10:15
We call this grid an attention pattern.
10:18
Now if you look at the original transformer paper,
10:20
there's a really compact way that they write this all down.
10:23
Here the variables q and k represent the full arrays of query
10:27
and key vectors respectively, those little vectors you get by
10:31
multiplying the embeddings by the query and the key matrices.
10:35
This expression up in the numerator is a really compact way to represent
10:39
the grid of all possible dot products between pairs of keys and queries.
10:44
A small technical detail that I didn't mention is that for numerical stability,
10:48
it happens to be helpful to divide all of these values by the
10:51
square root of the dimension in that key query space.
10:54
Then this softmax that's wrapped around the full expression
10:57
is meant to be understood to apply column by column.
11:01
As to that v term, we'll talk about it in just a second.
11:05
Before that, there's one other technical detail that so far I've skipped.
11:09
During the training process, when you run this model on a given text example,
11:12
and all of the weights are slightly adjusted and tuned to either reward or punish it
11:17
based on how high a probability it assigns to the true next word in the passage,
11:21
it turns out to make the whole training process a lot more efficient if you
11:25
simultaneously have it predict every possible next token following each initial
11:29
subsequence of tokens in this passage.
11:31
For example, with the phrase that we've been focusing on,
11:34
it might also be predicting what words follow creature and what words follow the.
11:39
This is really nice, because it means what would otherwise
11:42
be a single training example effectively acts as many.
11:46
For the purposes of our attention pattern, it means that you never
11:49
want to allow later words to influence earlier words,
11:52
since otherwise they could kind of give away the answer for what comes next.
11:56
What this means is that we want all of these spots here,
11:59
the ones representing later tokens influencing earlier ones,
12:02
to somehow be forced to be zero.
12:05
The simplest thing you might think to do is to set them equal to zero,
12:08
but if you did that the columns wouldn't add up to one anymore,
12:11
they wouldn't be normalized.
12:13
So instead, a common way to do this is that before applying softmax,
12:16
you set all of those entries to be negative infinity.
12:19
If you do that, then after applying softmax, all of those get turned into zero,
12:23
but the columns stay normalized.
12:26
This process is called masking.
12:27
There are versions of attention where you don't apply it, but in our GPT example,
12:31
even though this is more relevant during the training phase than it would be,
12:34
say, running it as a chatbot or something like that,
12:37
you do always apply this masking to prevent later tokens from influencing earlier ones.
12:42
Another fact that's worth reflecting on about this attention
12:45
pattern is how its size is equal to the square of the context size.
12:49
So this is why context size can be a really huge bottleneck for large language models,
12:53
and scaling it up is non-trivial.
12:56
As you imagine, motivated by a desire for bigger and bigger context windows,
13:00
recent years have seen some variations to the attention mechanism aimed at making
13:04
context more scalable, but right here, you and I are staying focused on the basics.
13:10
Okay, great, computing this pattern lets the model
13:12
deduce which words are relevant to which other words.
13:16
Now you need to actually update the embeddings,
13:18
allowing words to pass information to whichever other words they're relevant to.
13:22
For example, you want the embedding of Fluffy to somehow cause a change
13:26
to Creature that moves it to a different part of this 12,000-dimensional
13:30
embedding space that more specifically encodes a Fluffy creature.
13:35
What I'm going to do here is first show you the most straightforward
13:38
way that you could do this, though there's a slight way that
13:40
this gets modified in the context of multi-headed attention.
13:44
This most straightforward way would be to use a third matrix,
13:47
what we call the value matrix, which you multiply by the embedding of that first word,
13:51
for example Fluffy.
13:53
The result of this is what you would call a value vector,
13:55
and this is something that you add to the embedding of the second word,
13:59
in this case something you add to the embedding of Creature.
14:02
So this value vector lives in the same very high-dimensional space as the embeddings.
14:07
When you multiply this value matrix by the embedding of a word,
14:10
you might think of it as saying, if this word is relevant to adjusting the meaning of
14:15
something else, what exactly should be added to the embedding of that something else
14:19
in order to reflect this?
14:22
Looking back in our diagram, let's set aside all of the keys and the queries,
14:25
since after you compute the attention pattern you're done with those,
14:29
then you're going to take this value matrix and multiply it by every
14:32
one of those embeddings to produce a sequence of value vectors.
14:37
You might think of these value vectors as being
14:39
kind of associated with the corresponding keys.
14:42
For each column in this diagram, you multiply each of the
14:45
value vectors by the corresponding weight in that column.
14:50
For example here, under the embedding of Creature,
14:52
you would be adding large proportions of the value vectors for Fluffy and Blue,
14:57
while all of the other value vectors get zeroed out, or at least nearly zeroed out.
15:02
And then finally, the way to actually update the embedding associated with this column,
15:06
previously encoding some context-free meaning of Creature,
15:09
you add together all of these rescaled values in the column,
15:13
producing a change that you want to add, that I'll label delta-e,
15:16
and then you add that to the original embedding.
15:19
Hopefully what results is a more refined vector encoding the more
15:23
contextually rich meaning, like that of a fluffy blue creature.
15:27
And of course you don't just do this to one embedding,
15:30
you apply the same weighted sum across all of the columns in this picture,
15:34
producing a sequence of changes, adding all of those changes to the corresponding
15:38
embeddings, produces a full sequence of more refined embeddings popping out
15:42
of the attention block.
15:44
Zooming out, this whole process is what you would describe as a single head of attention.
15:49
As I've described things so far, this process is parameterized by three distinct
15:54
matrices, all filled with tunable parameters, the key, the query, and the value.
15:59
I want to take a moment to continue what we started in the last chapter,
16:02
with the scorekeeping where we count up the total number of model parameters using the
16:07
numbers from GPT-3.
16:09
These key and query matrices each have 12,288 columns, matching the embedding dimension,
16:15
and 128 rows, matching the dimension of that smaller key query space.
16:20
This gives us an additional 1.5 million or so parameters for each one.
16:24
If you look at that value matrix by contrast, the way I've described things so
16:30
far would suggest that it's a square matrix that has 12,288 columns and 12,288 rows,
16:35
since both its inputs and outputs live in this very large embedding space.
16:41
If true, that would mean about 150 million added parameters.
16:45
And to be clear, you could do that.
16:47
You could devote orders of magnitude more parameters
16:49
to the value map than to the key and query.
16:52
But in practice, it is much more efficient if instead you make
16:54
it so that the number of parameters devoted to this value map
16:57
is the same as the number devoted to the key and the query.
17:01
This is especially relevant in the setting of
17:03
running multiple attention heads in parallel.
17:06
The way this looks is that the value map is factored as a product of two smaller matrices.
17:11
Conceptually, I would still encourage you to think about the overall linear map,
17:15
one with inputs and outputs, both in this larger embedding space,
17:18
for example taking the embedding of blue to this blueness direction that you would
17:23
add to nouns.
17:27
The first matrix on the right here has a smaller number of rows,
17:30
typically the same size as the key-query space
17:33
What this means is you can think of it as mapping the
17:35
large embedding vectors down to a much smaller space.
17:39
This is not the conventional naming, but I'm going to call this the value down matrix.
17:43
The second matrix maps from this smaller space back up to the embedding space,
17:47
producing the vectors that you use to make the actual updates.
17:51
I'm going to call this one the value up matrix, which again is not conventional.
17:55
The way that you would see this written in most papers looks a little different.
17:58
I'll talk about it in a minute.
17:59
In my opinion, it tends to make things a little more conceptually confusing.
18:03
To throw in linear algebra jargon here, what we're basically doing is
18:06
constraining the overall value map to be a low rank transformation.
18:11
Turning back to the parameter count, all four of these matrices have the same size,
18:16
and adding them all up we get about 6.3 million parameters for one attention head.
18:22
As a quick side note, to be a little more accurate,
18:24
everything described so far is what people would call a self-attention head,
18:27
to distinguish it from a variation that comes up in other models that's
18:30
called cross-attention.
18:32
This isn't relevant to our GPT example, but if you're curious,
18:35
cross-attention involves models that process two distinct types of data,
18:39
like text in one language and text in another language that's part of an
18:43
ongoing generation of a translation, or maybe audio input of speech and an
18:47
ongoing transcription.
18:50
A cross-attention head looks almost identical.
18:52
The only difference is that the key and query maps act on different data sets.
18:57
In a model doing translation, for example, the keys might come from one language,
19:02
while the queries come from another, and the attention pattern could describe
19:06
which words from one language correspond to which words in another.
19:10
And in this setting there would typically be no masking,
19:12
since there's not really any notion of later tokens affecting earlier ones.
19:17
Staying focused on self-attention though, if you understood everything so far,
19:20
and if you were to stop here, you would come away with the essence of what attention
19:24
really is.
19:25
All that's really left to us is to lay out the sense
19:28
in which you do this many many different times.
19:32
In our central example we focused on adjectives updating nouns,
19:35
but of course there are lots of different ways that context can influence the
19:38
meaning of a word.
19:40
If the words they crashed the preceded the word car,
19:43
it has implications for the shape and structure of that car.
19:47
And a lot of associations might be less grammatical.
19:49
If the word wizard is anywhere in the same passage as Harry,
19:52
it suggests that this might be referring to Harry Potter,
19:55
whereas if instead the words Queen, Sussex, and William were in that passage,
19:59
then perhaps the embedding of Harry should instead be updated to refer to the prince.
20:05
For every different type of contextual updating that you might imagine,
20:08
the parameters of these key and query matrices would be different to
20:11
capture the different attention patterns, and the parameters of our
20:15
value map would be different based on what should be added to the embeddings.
20:19
And again, in practice the true behavior of these maps is much more
20:23
difficult to interpret, where the weights are set to do whatever the
20:26
model needs them to do to best accomplish its goal of predicting the next token.
20:31
As I said before, everything we described is a single head of attention,
20:35
and a full attention block inside a transformer consists of what's
20:38
called multi-headed attention, where you run a lot of these operations in parallel,
20:43
each with its own distinct key query and value maps.
20:47
GPT-3 for example uses 96 attention heads inside each block.
20:52
Considering that each one is already a bit confusing,
20:54
it's certainly a lot to hold in your head.
20:56
Just to spell it all out very explicitly, this means you have 96 distinct
21:01
key and query matrices producing 96 distinct attention patterns.
21:05
Then each head has its own distinct value matrices
21:08
used to produce 96 sequences of value vectors.
21:12
These are all added together using the corresponding attention patterns as weights.
21:17
What this means is that for each position in the context, each token,
21:21
every one of these heads produces a proposed change to be added to the embedding in
21:26
that position.
21:27
So what you do is you sum together all of those proposed changes, one for each head,
21:32
and you add the result to the original embedding of that position.
21:36
This entire sum here would be one slice of what's outputted from this multi-headed
21:41
attention block, a single one of those refined embeddings that pops out the other end
21:47
of it.
21:48
Again, this is a lot to think about, so don't
21:50
worry at all if it takes some time to sink in.
21:52
The overall idea is that by running many distinct heads in parallel,
21:56
you're giving the model the capacity to learn many distinct ways that context
22:00
changes meaning.
22:03
Pulling up our running tally for parameter count with 96 heads,
22:07
each including its own variation of these four matrices,
22:10
each block of multi-headed attention ends up with around 600 million parameters.
22:16
There's one added slightly annoying thing that I should really
22:19
mention for any of you who go on to read more about transformers.
22:22
You remember how I said that the value map is factored out into these two
22:25
distinct matrices, which I labeled as the value down and the value up matrices.
22:29
The way that I framed things would suggest that you see this pair of matrices
22:34
inside each attention head, and you could absolutely implement it this way.
22:38
That would be a valid design.
22:40
But the way that you see this written in papers and the way
22:42
that it's implemented in practice looks a little different.
22:45
All of these value up matrices for each head appear stapled together in one giant matrix
22:50
that we call the output matrix, associated with the entire multi-headed attention block.
22:56
And when you see people refer to the value matrix for a given attention head,
23:00
they're typically only referring to this first step,
23:03
the one that I was labeling as the value down projection into the smaller space.
23:08
For the curious among you, I've left an on-screen note about it.
23:11
It's one of those details that runs the risk of distracting
23:13
from the main conceptual points, but I do want to call it out
23:16
just so that you know if you read about this in other sources.
23:19
Setting aside all the technical nuances, in the preview from the last chapter we saw how
23:23
data flowing through a transformer doesn't just flow through a single attention block.
23:28
For one thing, it also goes through these other operations called multi-layer perceptrons.
23:33
We'll talk more about those in the next chapter.
23:35
And then it repeatedly goes through many many copies of both of these operations.
23:39
What this means is that after a given word imbibes some of its context,
23:43
there are many more chances for this more nuanced embedding
23:47
to be influenced by its more nuanced surroundings.
23:50
The further down the network you go, with each embedding taking in more and more
23:54
meaning from all the other embeddings, which themselves are getting more and more
23:59
nuanced, the hope is that there's the capacity to encode higher level and more
24:03
abstract ideas about a given input beyond just descriptors and grammatical structure.
24:07
Things like sentiment and tone and whether it's a poem and what underlying
24:11
scientific truths are relevant to the piece and things like that.
24:16
Turning back one more time to our scorekeeping, GPT-3 includes 96 distinct layers,
24:21
so the total number of key query and value parameters is multiplied by another 96,
24:27
which brings the total sum to just under 58 billion distinct parameters
24:31
devoted to all of the attention heads.
24:34
That is a lot to be sure, but it's only about a third
24:37
of the 175 billion that are in the network in total.
24:41
So even though attention gets all of the attention,
24:44
the majority of parameters come from the blocks sitting in between these steps.
24:48
In the next chapter, you and I will talk more about those
24:50
other blocks and also a lot more about the training process.
24:54
A big part of the story for the success of the attention mechanism is not so much any
24:58
specific kind of behaviour that it enables, but the fact that it's extremely
25:02
parallelizable, meaning that you can run a huge number of computations in a short time
25:07
using GPUs.
25:09
Given that one of the big lessons about deep learning in the last decade or two has
25:13
been that scale alone seems to give huge qualitative improvements in model performance,
25:17
there's a huge advantage to parallelizable architectures that let you do this.
25:22
If you want to learn more about this stuff, I've left lots of links in the description.
25:25
In particular, anything produced by Andrej Karpathy or Chris Ola tend to be pure gold.
25:30
In this video, I wanted to just jump into attention in its current form,
25:33
but if you're curious about more of the history for how we got here
25:36
and how you might reinvent this idea for yourself,
25:38
my friend Vivek just put up a couple videos giving a lot more of that motivation.
25:43
Also, Britt Cruz from the channel The Art of the Problem has a
25:45
really nice video about the history of large language models.
26:04
Thank you.