Neural Networks: Zero to Hero
Let's reproduce GPT-2 (124M)
Transcript
0:00
Hi everyone. Today we are going to be continuing our Zero to Hero series, and in particular today we are going to reproduce the GPT-2 model, the 124 million parameter version of it. When OpenAI released GPT-2, this was 2019, they released it with this blog post. On top of that they released this paper, and on top of that they released this code on GitHub, at openai/gpt-2.
0:27
Now when we talk about reproducing GPT-2 we have to be careful, because in particular in this video we're going to be reproducing the 124 million parameter model. The thing to realize is that there's always a miniseries when these releases are made, so there is the GPT-2 miniseries made up of models at different sizes, and usually the biggest model is called the GPT-2. Basically the reason we do that is because you can put the model sizes on the x-axis of plots like this, and on the y-axis you put a lot of downstream metrics that you're interested in, like translation, summarization, question answering and so on, and you can chart out these scaling laws. So basically as the model size increases you're getting better and better at downstream metrics.
1:07
And so in particular for GPT-2, if we scroll down in the paper, there are four models in the GPT-2 miniseries, starting at 124 million all the way up to 1558 million. Now the reason my numbers, the way I say them, disagree with this table is that this table is wrong. If you actually go to the GPT-2 GitHub repo, they sort of say that there was an error in how they added up the parameters. But basically this is the 124 million parameter model, etc. So the 124 million parameter model had 12 layers in the Transformer, and it had 768 channels in the Transformer, 768 dimensions. I'm going to be assuming some familiarity with what these terms mean, because I covered all of this in my previous video, Let's build GPT from scratch, the previous video in this playlist.
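As a rough check on the 124 million figure, the parameter count can be tallied from the layer shapes. The sketch below assumes the standard GPT-2 block layout (a fused query/key/value projection, an attention output projection, an MLP that widens to 4x the channel count, two LayerNorms per block, one final LayerNorm, and an output head that shares its weights with the token embedding). None of that layout is spelled out in this part of the video, the vocabulary size (50257) and context length (1024) are GPT-2's published values, and the function name is made up for illustration.

```python
def gpt2_param_count(n_layer=12, n_embd=768, vocab_size=50257, block_size=1024):
    """Tally parameters for a GPT-2 style model (assumed layout, see the lead-in)."""
    wte = vocab_size * n_embd                      # token embedding table
    wpe = block_size * n_embd                      # position embedding table
    ln = 2 * n_embd                                # one LayerNorm: weight + bias
    attn_qkv = n_embd * 3 * n_embd + 3 * n_embd    # fused q, k, v projection
    attn_out = n_embd * n_embd + n_embd            # attention output projection
    mlp_up = n_embd * 4 * n_embd + 4 * n_embd      # MLP expansion to 4 * n_embd
    mlp_down = 4 * n_embd * n_embd + n_embd        # MLP projection back to n_embd
    per_block = ln + attn_qkv + attn_out + ln + mlp_up + mlp_down
    # Assumption: the output head reuses wte, so it adds no parameters.
    return wte + wpe + n_layer * per_block + ln


total = gpt2_param_count()
print(f"{total:,} parameters (~{total / 1e6:.0f}M)")
```

Under these assumptions the smallest model comes out at roughly 124 million parameters, which is the number used throughout the video.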
1:59
Now if we do everything correctly and everything works out well, by the end of this video we're going to see something like this, where we're looking at the validation loss, which basically measures how good we are at predicting the next token in a sequence on some validation data that the model has not seen during training. We see that we go from doing that task not very well, because we're initializing from scratch, all the way to doing that task quite well by the end of the training. And hopefully we're going to beat the GPT-2 124M model.
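To put a number on "not very well": assuming the validation loss is the usual average cross-entropy over next-token predictions (an assumption, since the loss is not defined at this point in the video), a model that spreads its probability evenly over GPT-2's 50,257-token vocabulary scores about ln(50257). The helper name below is hypothetical.

```python
import math


def next_token_loss(probs_of_correct_token):
    """Average negative log-probability assigned to the true next tokens."""
    total = sum(math.log(p) for p in probs_of_correct_token)
    return -total / len(probs_of_correct_token)


vocab_size = 50257
# Assumption: a freshly initialized model is roughly uniform over the vocabulary.
uniform_loss = next_token_loss([1.0 / vocab_size] * 4)
print(round(uniform_loss, 2))
```

A loss near 10.8 is therefore a plausible starting point for a from-scratch run, and training should push it well below that.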
2:30
Now previously, when they were working on this, this is already 5 years ago, so this was probably a fairly complicated optimization at the time, and the GPUs and the compute were a lot smaller. Today you can reproduce this model in roughly an hour, or probably less even, and it will cost you about 10 bucks if you want to do this on cloud compute, a sort of computer that you can all rent. If you pay $10 for that computer and wait about an hour or less, you can actually achieve a model that is as good as this model that OpenAI released.
3:02
One more thing to mention is that, unlike many other models, OpenAI did release the weights for GPT-2, so those weights are all available in this repository. But the GPT-2 paper is not always as good with all of the details of training, so in addition to the GPT-2 paper we're going to be referencing the GPT-3 paper, which is a lot more concrete in a lot of the hyperparameters and optimization settings and so on, and it's not a huge departure in the architecture from the GPT-2 version of the model. So we're going to be referencing both GPT-2 and GPT-3 as we try to reproduce GPT-2 124M. So let's go.
3:39
The first thing I would like to do is actually start at the end, or at the target. In other words, let's load the GPT-2 124M model as it was released by OpenAI and maybe take it for a spin, let's sample some tokens from it.
3:52
Now the issue with that is, when you go into the code base of GPT-2 and you go into the source and you click on model.py, you'll realize that actually this is using TensorFlow. So the original GPT-2 code here was written in TensorFlow, which is, let's just say, not used as much anymore. We'd like to use PyTorch, because it's a lot friendlier, easier, and I just personally like it a lot more. The problem with that is the initial code is in TensorFlow and we'd like to use PyTorch. So instead, to get the target, we're going to use the Hugging Face Transformers code, which I like a lot more.
4:28
So when you go into the Transformers source, src/transformers/models/gpt2/modeling_gpt2.py, you will see that they have the GPT-2 implementation of that Transformer here in this file. It's like medium readable, but not fully readable. But what it does is it did all the work of converting all those weights from TensorFlow to be PyTorch friendly, and so it's much easier to load and work with.
4:54
So in particular we can look at the GPT-2 model here, and we can load it using Hugging Face Transformers. Swinging over, this is what that looks like: from transformers import GPT2LMHeadModel, and then from_pretrained with gpt2. Now one awkward thing about this is that when you do gpt2 as the model that we're loading, this actually is the 124 million parameter model. If you want the actual GPT-2, the 1.5 billion, then you actually want to do gpt2-xl. So this is the 124M, our target. Now what we're doing is, when we actually get this, we're initializing the PyTorch nn.Module as defined here in this class. From it I want to get just the state dict, which is just the raw tensors. So we just have the tensors of that file.
5:49
By the way, this here is a Jupyter notebook, but it is a Jupyter notebook running inside VS Code. I like to work with it all in a single sort of interface, so I like to use VS Code. So this is the Jupyter notebook extension inside VS Code.
6:06
So when we get the state dict, this is just a dict, so we can print the key and the value, which is the tensor, and let's just look at the shapes. These are the different parameters inside the GPT-2 model and their shapes.
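The steps just described can be sketched as follows, assuming the transformers and torch packages are installed and the gpt2 checkpoint can be downloaded. The wrapper names `gpt2_state_dict_shapes` and `print_shapes` are made up here so that nothing is downloaded until the function is called; the notebook in the video runs the same two library calls at the top level.

```python
def gpt2_state_dict_shapes(model_name="gpt2"):
    """Load a pretrained GPT-2 and return {parameter name: shape}.

    "gpt2" is the 124M model; "gpt2-xl" is the 1.5B one.
    """
    # Imported lazily so that defining this function needs no heavy dependencies.
    from transformers import GPT2LMHeadModel

    model = GPT2LMHeadModel.from_pretrained(model_name)  # downloads weights on first use
    sd = model.state_dict()  # plain mapping: parameter name -> tensor
    return {name: tuple(tensor.shape) for name, tensor in sd.items()}


def print_shapes(shapes):
    """Print one 'name shape' line per parameter."""
    for name, shape in shapes.items():
        print(name, shape)
```

Calling `print_shapes(gpt2_state_dict_shapes())` should list entries such as `transformer.wte.weight` and `transformer.wpe.weight` along with the weights and biases of each Transformer block.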
6:22
So the wte weight, for token embedding, is of size 50257 by 768. Where this is coming from is that we have 50257 tokens in the GPT-2 vocabulary. And the tokens, by the way, are exactly the tokens that we've spoken about in the previous video, in my tokenization series. In the previous video, just before this, I go into a ton of detail on tokenization. The GPT-2 tokenizer happens to have this many tokens. For each token we have a 768-dimensional embedding that is the distributed representation that stands in for that token. So each token is a little string piece, and then the 768 numbers are the vector that represents that token. And so this is just our lookup table for tokens.
7:13
And then here we have the lookup table for the positions. Because GPT-2 has a maximum sequence length of 1024, we have up to 1,024 positions that each token can be attending to in the past, and every one of those positions in GPT-2 has a fixed vector of 768 that is learned by optimization. So this is the position embedding and the token embedding. And then everything here is just the other weights and biases and everything else of this Transformer.
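As a toy illustration of the two lookup tables: each token id selects a row of the token table, and each position selects a row of the position table. The sketch assumes the two rows are added element-wise to form the input to the first block, which is how GPT-2 combines them but is not stated in this passage. The tiny sizes, the random values, and all names are illustrative stand-ins, not the real weights.

```python
import random


def make_table(rows, dim, rng):
    """A rows x dim table of small random floats, standing in for learned weights."""
    return [[rng.gauss(0.0, 0.02) for _ in range(dim)] for _ in range(rows)]


def embed(token_ids, wte, wpe):
    """Look up the token row and the position row, then add them element-wise."""
    if len(token_ids) > len(wpe):
        raise ValueError("sequence is longer than the position table")
    return [
        [t + p for t, p in zip(wte[tok], wpe[pos])]
        for pos, tok in enumerate(token_ids)
    ]


rng = random.Random(0)
vocab_size, block_size, n_embd = 10, 8, 4   # GPT-2 124M uses 50257, 1024, 768
wte = make_table(vocab_size, n_embd, rng)   # token lookup table
wpe = make_table(block_size, n_embd, rng)   # position lookup table

x = embed([3, 3, 7], wte, wpe)
print(len(x), len(x[0]))  # one n_embd-sized vector per input token
```

Note that the same token id (3) at positions 0 and 1 yields two different vectors, because the position rows differ. That is what lets the Transformer tell the two occurrences apart.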
7:47
So when you take, for example, the positional embeddings and flatten them out and take just the first 20 elements, you can see that these are just the parameters. These are weights, floats, that we can just take and plot. So these are the position embeddings, and we get something like this.
8:03
You can see that this has structure, and it has structure because what we have here really is that every row in this visualization is a different position, a fixed absolute position in the range from 0 to 1024, and each row here is the representation of that position. It has structure because these positional embeddings end up learning these sinusoids and cosines that sort of represent each of these positions. Each row here stands in for that position and is processed by the Transformer to recover all the relative positions, to sort of realize which token is where, and to attend to them depending on their position, not just their content.
8:49
So when we actually look into an individual column inside these, and I just grabbed three random columns, you'll see that here we are focusing on a single channel, and we're looking at what that channel is doing as a function of position, from zero to 1023 really. We can see that some of these channels basically respond more or less to different parts of the position spectrum. This green channel really likes to fire for everything after 200 up to 800, but a lot less elsewhere, and has a sharp drop-off here near zero. So who knows what these embeddings are doing and why they are the way they are.
9:36
You can tell, for example, that because they're a bit jagged and kind of noisy, this model was not fully trained, and the more trained this model was, the more you would expect this to smooth out. So this is telling you that this is a little bit of an undertrained model. But in principle these curves don't even have to be smooth. This should just be totally random noise, and in fact in the beginning of the optimization it is complete random noise, because this position embedding table is initialized completely at random. So in the beginning you have jaggedness, and the fact that you end up with something smooth is already kind of impressive, that that just falls out of the optimization, because in principle you shouldn't even be able to get any single graph out of this that makes sense. But we actually get something that looks a little bit noisy, but for the most part looks sinusoidal-like.
10:24
In the original Transformer paper, the Attention Is All You Need paper, the positional embeddings are actually initialized and fixed, if I remember correctly, to sinusoids and cosines of different frequencies, and that's the positional encoding, and it's fixed. But in GPT-2 these are just parameters, and they're trained from scratch just like any other parameter, and that seems to work about as well. So what they do is they kind of recover these sinusoidal-like features during the optimization.
10:54
optimization we can also look at any of
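For reference, here is a minimal sketch of the fixed sinusoidal table from the Attention Is All You Need paper, written in plain Python. The function name `sinusoidal_position_table` is made up for illustration; the formula is the paper's, with even columns using sine and odd columns using cosine at geometrically spaced frequencies. GPT-2 does not use this table; it learns its position embeddings instead.

```python
import math


def sinusoidal_position_table(n_positions, d_model):
    """Build the fixed positional encoding table, one row per position."""
    table = []
    for pos in range(n_positions):
        row = []
        # i walks over the even column indices 0, 2, 4, ...
        for i in range(0, d_model, 2):
            angle = pos / (10000 ** (i / d_model))
            row.append(math.sin(angle))  # column i
            row.append(math.cos(angle))  # column i + 1
        table.append(row[:d_model])  # trim in case d_model is odd
    return table


table = sinusoidal_position_table(n_positions=4, d_model=6)
for pos, row in enumerate(table):
    print(pos, [round(v, 3) for v in row])
```

Each column is a smooth wave over the position index, which is the kind of shape the learned GPT-2 position embeddings end up resembling.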
10:54
We can also look at any of the other matrices here. So here I took the first layer of the Transformer, and I'm looking at one of its weights, just the first block of 300 by 300, and you see some structure. But again, who knows what any of this is. If you're into mechanistic interpretability you might get a real kick out of trying to figure out what is going on, what this structure is, and what it all means, but we're not going to be doing that in this video. We definitely see that there's some interesting structure, and that's kind of cool.
11:26
What we're mostly interested in is that we've loaded the weights of this model that was released by OpenAI, and now, using Hugging Face Transformers, we can not just get all the raw weights but also get what they call a pipeline and sample from it. So this is the prefix, "Hello, I'm a language model,", and then we're sampling 30 tokens and getting five sequences. I ran this, and this is what it produced: "Hello, I'm a language model, but what I'm really doing is making a human-readable document. There are other languages, but those are..." So you can read through these if you like, but basically these are five different completions of the same prefix from this GPT-2 124M.
12:09
Now, if I go here: I took this example from here, and sadly, even though we are fixing the seed, we are getting different generations from the snippet than what they got, so presumably the code changed. But what's important at this stage is that we are getting coherent text, so we've loaded the model successfully.
12:34
We can look at all its parameters, and the keys tell us where in the model these come from. But we want to actually write our own GPT-2 class so that we have full understanding of what's happening there. We don't want to be working with something like modeling_gpt2.py, because it's just too complicated; we want to write this from scratch ourselves. So we're going to be implementing the GPT model here in parallel, and as our first task let's load the GPT-2 124M into the class that we're going to develop here from scratch. That's going to give us confidence that we can load the OpenAI model, and therefore that there's a setting of weights that exactly is the 124M model.
13:10
But then, of course, what we're going to do is initialize the model from scratch instead and try to train it ourselves on a bunch of documents that we're going to get, and we're going to try to surpass that model. So we're going to get different weights and everything's going to look different, hopefully better even. But we're going to have a lot of confidence that, because we can load the OpenAI model, we are in the same model family and model class, and we just have to rediscover a good setting of the weights, but from scratch. So let's now write the GPT-2 model, load the weights, and make sure that we can also generate text that looks coherent.
13:48
Okay, so let's now swing over to the Attention Is All You Need paper that started everything, and let's scroll over to the model architecture, the original Transformer. Now remember that GPT-2 is slightly modified from the original Transformer. In particular, we do not have the encoder: GPT-2 is a decoder-only Transformer, as we call it, so this entire encoder here is missing. In addition to that, this cross-attention here that was using that encoder is also missing, so we delete this entire part. Everything else stays almost the same, but there are some differences that we're going to look at here.
14:21
So there are two main differences. When we go to the GPT-2 paper, under 2.3 Model, we notice that first, there's a reshuffling of the layer norms, so they change place, and second, an additional layer normalization was added after the final self-attention block. So basically all the layer norms here, instead of being after the MLP or after the attention, swing before it, and an additional layer norm gets added here, right before the final classifier.
14:56
classifier so now let's Implement some of the first sort of skeleton NN module
14:59
of the first sort of skeleton NN module
14:59
of the first sort of skeleton NN module modules here in our GPT NN module and in
15:02
modules here in our GPT NN module and in
15:02
modules here in our GPT NN module and in particular we're going to try to match
15:04
particular we're going to try to match
15:04
particular we're going to try to match up this schema here that is used by
15:06
up this schema here that is used by
15:06
up this schema here that is used by hugging face Transformers because that
15:08
hugging face Transformers because that
15:08
hugging face Transformers because that will make it much easier to load these
15:10
will make it much easier to load these
15:10
will make it much easier to load these weights from this state dict so we want
15:12
weights from this state dict so we want
15:12
weights from this state dict so we want something that reflects uh this schema
15:15
something that reflects uh this schema
15:15
something that reflects uh this schema here so here's what I came up with
15:19
here so here's what I came up with
15:19
here so here's what I came up with um basically we see that the main
15:22
um basically we see that the main
15:22
um basically we see that the main container here that has all the modules
15:24
container here that has all the modules
15:24
container here that has all the modules is called Transformer so I'm reflecting
15:26
is called Transformer so I'm reflecting
15:26
is called Transformer so I'm reflecting that with an NN module dict and this is
15:29
that with an NN module dict and this is
15:29
that with an NN module dict and this is basically a module that allows you to
15:30
basically a module that allows you to
15:30
basically a module that allows you to index into the subm modules using keys
15:34
index into the subm modules using keys
15:34
index into the subm modules using keys just like a dictionary uh
15:36
just like a dictionary uh
15:36
just like a dictionary uh strings within it we have the weights of
15:39
strings within it we have the weights of
15:39
strings within it we have the weights of the token embeddings WT and that's an N
15:41
the token embeddings WT and that's an N
15:41
the token embeddings WT and that's an N embedding and the weights of the
15:44
embedding and the weights of the
15:44
embedding and the weights of the position embeddings which is also just
15:45
position embeddings which is also just
15:45
position embeddings which is also just an N embedding and if you remember n
15:47
an N embedding and if you remember n
15:47
an N embedding and if you remember n embedding is really just a fancy little
15:49
embedding is really just a fancy little
15:49
embedding is really just a fancy little wrapper module around just a single um
15:53
wrapper module around just a single um
15:53
wrapper module around just a single um single array of numbers a single uh
15:56
single array of numbers a single uh
15:56
single array of numbers a single uh block of numbers just like this it's a
15:58
block of numbers just like this it's a
15:58
block of numbers just like this it's a single tensor and an embedding is a
16:01
single tensor and an embedding is a
16:02
single tensor and an embedding is a glorified um wrapper around a tensor
16:04
glorified um wrapper around a tensor
16:04
glorified um wrapper around a tensor that allows you to access its elements
16:07
that allows you to access its elements
16:07
that allows you to access its elements uh by indexing into the
16:08
uh by indexing into the
16:08
uh by indexing into the rows now in addition to that we see here
16:11
rows now in addition to that we see here
16:11
rows now in addition to that we see here that we have a h and then there's a this
16:14
that we have a h and then there's a this
16:14
that we have a h and then there's a this is index using numbers instead of
16:16
is index using numbers instead of
16:16
is index using numbers instead of indexed using strings so there's a h. 0
16:19
indexed using strings so there's a h. 0
16:19
indexed using strings so there's a h. 0 1 2 Etc all the way up till h. 11 and
16:23
1 2 Etc all the way up till h. 11 and
16:23
1 2 Etc all the way up till h. 11 and that's because there are 12 layers here
16:25
that's because there are 12 layers here
16:26
that's because there are 12 layers here in this Transformer so to reflect that
16:28
in this Transformer so to reflect that
16:28
in this Transformer so to reflect that I'm creating also an H I think that
16:31
I'm creating also an H I think that
16:31
I'm creating also an H I think that probably stands for hidden and instead
16:33
probably stands for hidden and instead
16:33
probably stands for hidden and instead of a module dict this is a model list so
16:35
of a module dict this is a model list so
16:35
of a module dict this is a model list so we can index it using integers exactly
16:37
we can index it using integers exactly
16:37
we can index it using integers exactly as we see here 01 2 Etc and the modular
16:41
as we see here 01 2 Etc and the modular
16:42
as we see here 01 2 Etc and the modular list has a n layer blocks and the blocks
16:46
list has a n layer blocks and the blocks
16:46
list has a n layer blocks and the blocks are yet to be defined in a module in a
16:48
are yet to be defined in a module in a
16:48
are yet to be defined in a module in a bit in addition to that following the
16:50
bit in addition to that following the
16:50
bit in addition to that following the gpt2 paper we have we need an additional
16:53
gpt2 paper we have we need an additional
16:53
gpt2 paper we have we need an additional final layer Norm that we're going to put
16:56
final layer Norm that we're going to put
16:56
final layer Norm that we're going to put in there and then we have the final
16:58
in there and then we have the final
16:58
in there and then we have the final classifier uh the language model head
17:01
classifier uh the language model head
17:01
classifier uh the language model head which um projects from 768 the number of
17:05
which um projects from 768 the number of
17:05
which um projects from 768 the number of embedding dimensions in this GPT all the
17:08
embedding dimensions in this GPT all the
17:08
embedding dimensions in this GPT all the way to the vocab size which is
17:10
way to the vocab size which is
17:10
way to the vocab size which is 50257 and gpt2 uses no bias for this
17:13
50257 and gpt2 uses no bias for this
17:13
50257 and gpt2 uses no bias for this final uh sort of projection so this is
17:16
final uh sort of projection so this is
17:16
final uh sort of projection so this is the skeleton and you can see that it
17:19
the skeleton and you can see that it
17:19
the skeleton and you can see that it reflects this so the wte is the token
17:22
reflects this so the wte is the token
17:22
reflects this so the wte is the token embeddings here it's called output
17:24
embeddings here it's called output
17:24
embeddings here it's called output embedding but it's really the token
17:26
embedding but it's really the token
17:26
embedding but it's really the token embeddings the PE is the positional
17:29
embeddings the PE is the positional
17:29
embeddings the PE is the positional codings uh those two pieces of
17:31
codings uh those two pieces of
17:31
codings uh those two pieces of information as we saw previously are
17:32
information as we saw previously are
17:32
information as we saw previously are going to add and then go into the
17:34
going to add and then go into the
17:34
going to add and then go into the Transformer the H is the all the blocks
17:37
Transformer the H is the all the blocks
17:37
Transformer the H is the all the blocks in Gray and the LNF is this new layer
17:40
in Gray and the LNF is this new layer
17:40
in Gray and the LNF is this new layer that gets added here by the gpt2 model
17:43
that gets added here by the gpt2 model
17:43
that gets added here by the gpt2 model and LM head is this linear part here so
17:47
and LM head is this linear part here so
17:47
and LM head is this linear part here so that's the skeleton of the gpt2 we now
17:50
that's the skeleton of the gpt2 we now
17:50
that's the skeleton of the gpt2 we now have to implement the block okay so
17:52
have to implement the block okay so
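A skeleton along these lines might look like the following PyTorch sketch. The `GPTConfig` dataclass and its field names are assumptions for illustration, and `Block` is an empty placeholder, since the real block has not been defined yet at this point. The defaults are the sizes quoted in the text (12 layers, 768 embedding dimensions, a vocabulary of 50257) plus GPT-2's 1,024-token context length; a tiny config is used at the bottom so the example stays cheap to run.

```python
from dataclasses import dataclass

import torch.nn as nn


@dataclass
class GPTConfig:
    # Field names are illustrative; the defaults are the GPT-2 124M sizes.
    block_size: int = 1024   # maximum sequence length
    vocab_size: int = 50257  # number of tokens in the vocabulary
    n_layer: int = 12        # number of Transformer blocks
    n_embd: int = 768        # embedding dimension


class Block(nn.Module):
    """Placeholder only: the real block (attention + MLP) comes later."""

    def __init__(self, config):
        super().__init__()

    def forward(self, x):
        return x


class GPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        # Submodule names mirror the Hugging Face state-dict keys.
        self.transformer = nn.ModuleDict(dict(
            wte=nn.Embedding(config.vocab_size, config.n_embd),  # token embeddings
            wpe=nn.Embedding(config.block_size, config.n_embd),  # position embeddings
            h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f=nn.LayerNorm(config.n_embd),                    # final layer norm
        ))
        # Final classifier: n_embd -> vocab_size, no bias.
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)


# A deliberately tiny config, just to inspect the key names.
model = GPT(GPTConfig(block_size=8, vocab_size=100, n_layer=2, n_embd=16))
keys = list(model.state_dict().keys())
print(keys)
```

Because the attribute names match the Hugging Face schema, the printed keys (`transformer.wte.weight`, `transformer.ln_f.bias`, `lm_head.weight`, and so on) line up with the keys of the pretrained state dict.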
17:53
have to implement the block okay so let's now recurse to the block itself so
17:55
let's now recurse to the block itself so
17:55
let's now recurse to the block itself so we want to define the block um so I'll
17:58
we want to define the block um so I'll
17:59
we want to define the block um so I'll start putting them here so the block I
18:02
start putting them here so the block I
18:02
start putting them here so the block I like to write out like
18:04
like to write out like
18:04
like to write out like this uh these are some of the
18:06
this uh these are some of the
18:06
this uh these are some of the initializations and then this is the
18:07
initializations and then this is the
18:07
initializations and then this is the actual forward pass of what this block
18:09
actual forward pass of what this block
18:09
actual forward pass of what this block computes and notice here that there's a
18:12
computes and notice here that there's a
18:12
computes and notice here that there's a change from the Transformer again that
18:14
change from the Transformer again that
18:14
change from the Transformer again that is mentioned in the gpt2 paper so here
18:17
is mentioned in the gpt2 paper so here
18:17
is mentioned in the gpt2 paper so here the layer normalizations are after the
18:20
the layer normalizations are after the
18:20
the layer normalizations are after the application of attention or feed forward
18:22
application of attention or feed forward
18:22
application of attention or feed forward in addition to that note that the
18:24
in addition to that note that the
18:24
in addition to that note that the normalizations are inside the residual
18:26
normalizations are inside the residual
18:26
normalizations are inside the residual stream you see how feed forward is
18:28
stream you see how feed forward is
18:28
stream you see how feed forward is applied and this arrow goes through and
18:30
applied and this arrow goes through and
18:30
applied and this arrow goes through and through the normalization so that means
18:33
through the normalization so that means
18:33
through the normalization so that means that your residual pathway has
18:35
that your residual pathway has
18:35
that your residual pathway has normalizations inside them and this is
18:37
normalizations inside them and this is
18:37
normalizations inside them and this is not very good or desirable uh you
18:39
not very good or desirable uh you
18:39
not very good or desirable uh you actually prefer to have a single uh
18:41
actually prefer to have a single uh
18:42
actually prefer to have a single uh clean residual stream all the way from
18:44
clean residual stream all the way from
18:44
clean residual stream all the way from supervision all the way down to the
18:45
supervision all the way down to the
18:45
supervision all the way down to the inputs the tokens and this is very
18:48
inputs the tokens and this is very
18:48
inputs the tokens and this is very desirable and nice because the gradients
18:51
desirable and nice because the gradients
18:51
desirable and nice because the gradients that flow from the top if you remember
18:53
that flow from the top if you remember
18:54
that flow from the top if you remember from your microad addition just
18:56
from your microad addition just
18:56
from your microad addition just distributes gradients during the
18:58
distributes gradients during the
18:58
distributes gradients during the backwards state to both of its branches
19:00
backwards state to both of its branches
19:00
backwards state to both of its branches equally so addition is a branch in the
19:04
equally so addition is a branch in the
19:04
equally so addition is a branch in the gradients and so that means that the
19:06
gradients and so that means that the
19:06
gradients and so that means that the gradients from the top flows straight to
19:08
gradients from the top flows straight to
19:08
gradients from the top flows straight to the inputs the tokens through the
19:10
the inputs the tokens through the
19:10
the inputs the tokens through the residual Pathways unchanged but then in
19:13
residual Pathways unchanged but then in
19:13
residual Pathways unchanged but then in addition to that the gradient also flows
19:14
addition to that the gradient also flows
19:14
addition to that the gradient also flows through the blocks and the blocks you
19:16
through the blocks and the blocks you
19:17
through the blocks and the blocks you know contribute their own contribution
19:18
know contribute their own contribution
19:18
know contribute their own contribution over time and kick in and change the
19:20
over time and kick in and change the
19:20
over time and kick in and change the optimization over time but basically
19:22
optimization over time but basically
19:22
optimization over time but basically clean residual pathway is desirable from
19:25
clean residual pathway is desirable from
19:25
clean residual pathway is desirable from an optimization perspective and then the
19:28
an optimization perspective and then the
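To make the point about a clean residual pathway concrete, here is a minimal sketch in plain Python. It assumes a layer norm with no learned scale or shift and a stand-in sublayer that contributes nothing (roughly what a block looks like before it has learned anything). The names `post_norm_block`, `pre_norm_block`, and `silent` are made up for illustration and are not from the video's code.

```python
import math


def layer_norm(x, eps=1e-5):
    """Normalize a vector to zero mean and unit variance (no learned parameters)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def post_norm_block(x, sublayer):
    # Original Transformer ordering: the norm sits on the residual pathway.
    return layer_norm([a + b for a, b in zip(x, sublayer(x))])


def pre_norm_block(x, sublayer):
    # GPT-2 ordering: the norm sits inside the branch, the residual stays clean.
    return [a + b for a, b in zip(x, sublayer(layer_norm(x)))]


def silent(x):
    # Stand-in sublayer that contributes nothing to the residual stream.
    return [0.0] * len(x)


x = [1.0, 2.0, 4.0]
pre_out = pre_norm_block(x, silent)
post_out = post_norm_block(x, silent)
print("pre-norm: ", pre_out)
print("post-norm:", [round(v, 3) for v in post_out])
```

With the pre-norm ordering the input comes out untouched, while the post-norm ordering rescales it even though the sublayer did nothing.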
19:28
an optimization perspective and then the this is the pre-normalization version
19:30
this is the pre-normalization version
19:30
this is the pre-normalization version where you see that RX first goes through
19:32
where you see that RX first goes through
19:32
where you see that RX first goes through the layer normalization and then the
19:34
the layer normalization and then the
19:34
the layer normalization and then the attention and then goes uh back out to
19:38
attention and then goes uh back out to
19:38
attention and then goes uh back out to go to the L ration number two and the
19:40
go to the L ration number two and the
19:40
go to the L ration number two and the multia perceptron sometimes also
19:43
multia perceptron sometimes also
19:43
multia perceptron sometimes also referred to as a feed forward Network or
19:44
referred to as a feed forward Network or
19:44
referred to as a feed forward Network or an FFN and then that goes into the
19:47
an FFN and then that goes into the
19:47
an FFN and then that goes into the residual stream again and the one more
19:50
residual stream again and the one more
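The forward pass described here can be sketched in PyTorch roughly as follows. The attention and MLP sublayers are passed in as arguments because they have not been defined yet at this point; the `nn.Linear` stand-ins in the usage lines are placeholders only (a linear layer does not mix information across tokens, so it is not real attention). The names `ln_1`, `attn`, `ln_2`, and `mlp` are assumed to follow the Hugging Face key names.

```python
import torch
import torch.nn as nn


class Block(nn.Module):
    def __init__(self, n_embd, attn, mlp):
        super().__init__()
        self.ln_1 = nn.LayerNorm(n_embd)
        self.attn = attn  # placeholder for the attention sublayer
        self.ln_2 = nn.LayerNorm(n_embd)
        self.mlp = mlp    # placeholder for the MLP sublayer

    def forward(self, x):
        # Pre-norm: normalize inside each branch, then add back to the stream.
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


n_embd = 8
block = Block(
    n_embd,
    attn=nn.Linear(n_embd, n_embd),  # stand-in, not real attention
    mlp=nn.Linear(n_embd, n_embd),   # stand-in, not the real MLP
)
x = torch.randn(2, 5, n_embd)  # (batch, tokens, channels)
y = block(x)
print(y.shape)
```

Note that `x` itself is never passed through a layer norm on its way from input to output; only the copies that feed the two branches are normalized.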
19:50
residual stream again and the one more thing that is kind of interesting to
19:51
thing that is kind of interesting to
19:51
thing that is kind of interesting to note is that recall that attention is a
19:53
note is that recall that attention is a
19:53
note is that recall that attention is a communication operation it is where all
19:55
communication operation it is where all
19:55
communication operation it is where all the tokens and there's 1,24 tokens lined
19:58
the tokens, and there are 1,024 tokens lined up in a sequence, and this is where the tokens communicate; this is where they exchange information.
20:06
So attention is an aggregation function. It's a pooling function, it's a weighted sum function, it is a reduce operation, whereas the MLP here happens at every single token individually; there's no information being collected or exchanged between the tokens.
20:24
So the attention is the reduce and the MLP is the map, and what you end up with is that the Transformer just ends up being a repeated application of map-reduce, if you want to think about it that way.
20:36
So this is where they communicate, and this is where they think individually about the information that they gathered. Every one of these blocks iteratively refines the representation inside the residual stream.
20:51
So this is our block, slightly modified from this picture.
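To make the map-reduce picture concrete, here is a toy sketch in plain Python. It is not the code from the video: uniform causal averaging stands in for attention, a fixed per-token function stands in for the MLP, layer norm is left out, and the names reduce_step, map_step and block are made up for illustration.

```python
# Toy illustration only: uniform causal averaging stands in for attention,
# and a fixed elementwise function stands in for the MLP (assumptions).

def reduce_step(tokens):
    """Communicate: each position pools (averages) itself and all earlier positions."""
    dim = len(tokens[0])
    out = []
    for t in range(len(tokens)):
        window = tokens[: t + 1]
        out.append([sum(vec[d] for vec in window) / len(window) for d in range(dim)])
    return out

def map_step(tokens):
    """Think individually: the same function is applied to every position on its own."""
    return [[max(0.0, 2.0 * v - 1.0) for v in vec] for vec in tokens]

def block(tokens):
    """One block: reduce, then map, each added back onto the residual stream."""
    pooled = reduce_step(tokens)
    tokens = [[a + b for a, b in zip(x, p)] for x, p in zip(tokens, pooled)]
    mapped = map_step(tokens)
    return [[a + b for a, b in zip(x, m)] for x, m in zip(tokens, mapped)]

tokens = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
for _ in range(3):  # repeated application of reduce + map
    tokens = block(tokens)
print(tokens)
```

The only step where one position can see another is reduce_step; map_step never looks across positions, which is the distinction being drawn between attention and the MLP.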
20:53
Okay, so let's now move on to the MLP. The MLP block I implemented as follows. It is relatively straightforward: we basically have two linear projections here, with the GELU nonlinearity sandwiched in between them, so nn.GELU(approximate='tanh').
21:13
Now, when we swing over to the PyTorch documentation, this is nn.GELU, and it has this format. It has two versions: the original version of GELU, which we'll step into in a bit, and the approximate version of GELU, which we can request using tanh.
21:28
So as you can see, just as a preview here, GELU is basically like a ReLU, except there's no exactly flat tail here at exactly zero. Otherwise it looks very much like a slightly smoother ReLU.
21:43
It comes from this paper here, Gaussian Error Linear Units, and you can step through this paper. There's some mathematical reasoning that leads to an interpretation that leads to the specific formulation. It has to do with stochastic regularizers and the expectation of a modification to adaptive dropout, so you can read through all of that if you'd like.
22:03
There's a little bit of history as to why there is an approximate version of GELU, and that comes from this issue here, as far as I can tell. In this issue, Dan Hendrycks mentions that at the time when they developed this nonlinearity, the erf function, which you need to evaluate the exact GELU, was very slow in TensorFlow, so they ended up basically developing this approximation. This approximation then ended up being picked up by BERT and by GPT-2, etc.
22:30
But today there's no real good reason to use the approximate version. You'd prefer to just use the exact version, because my expectation is that there's no big difference anymore, and this is kind of like a historical quirk. But we are trying to reproduce GPT-2 exactly, and GPT-2 used the tanh approximate version, so we prefer to stick with that.
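As a quick sanity check on the "no big difference" expectation, the two GELU variants can be compared in plain Python. This is a sketch using the standard formulas (exact GELU via erf, and the tanh approximation with the 0.044715 constant); the function names gelu_exact and gelu_tanh are made up here, and this is not the PyTorch implementation itself.

```python
import math

def gelu_exact(x):
    # Exact GELU: x * Phi(x), where Phi is the standard normal CDF written via erf.
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def gelu_tanh(x):
    # Tanh approximation, the variant GPT-2 used.
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)
    return 0.5 * x * (1.0 + math.tanh(inner))

# Scan a grid on [-5, 5] and report the largest absolute gap between the two.
xs = [i / 100.0 for i in range(-500, 501)]
max_gap = max(abs(gelu_exact(x) - gelu_tanh(x)) for x in xs)
print(f"largest gap on [-5, 5]: {max_gap:.6f}")
```

On this grid the gap stays small, which is consistent with treating the approximation as a historical quirk rather than a meaningful modeling choice.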
22:55
Now, one other reason to actually just intuitively use GELU instead of ReLU: in previous videos we've spoken about the dead ReLU neuron problem, where in this tail of a ReLU, if it's exactly flat at zero, any activations that fall there will get exactly zero gradient. There's no change, there's no adaptation, there's no development of the network if any of these activations end up in this flat region.
23:20
But the GELU always contributes a local gradient, and so there's always going to be a change, always going to be an adaptation, and sort of smoothing it out ends up empirically working better in practice, as demonstrated in this paper and also as demonstrated by it being picked up by the BERT paper, the GPT-2 paper, and so on. So for that reason we adopt this nonlinearity here in the GPT-2 reproduction.
23:40
Now, in more modern networks like Llama 3 and so on, this nonlinearity further changes to SwiGLU and other variants like that, but for GPT-2 they used this approximate GELU.
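The dead-ReLU point can be illustrated by comparing local gradients in the negative tail. This is a small sketch using the analytic derivative of the exact GELU (the function names are made up for illustration). One caveat that is my own note, not from the video: the GELU slope does pass through zero at one isolated point near x = -0.75, but it is not flat over a whole region the way ReLU is.

```python
import math

def relu_grad(x):
    # ReLU is flat for x <= 0, so its local gradient there is exactly zero.
    return 1.0 if x > 0 else 0.0

def gelu_grad(x):
    # d/dx [x * Phi(x)] = Phi(x) + x * phi(x), with Phi/phi the normal CDF/PDF.
    cdf = 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))
    pdf = math.exp(-0.5 * x * x) / math.sqrt(2.0 * math.pi)
    return cdf + x * pdf

for x in [-3.0, -1.0, -0.1, 0.5]:
    print(f"x={x:+.1f}  relu'={relu_grad(x):.4f}  gelu'={gelu_grad(x):+.4f}")
```

For the negative inputs, the ReLU column is exactly zero while the GELU column is small but nonzero, so a neuron sitting there still receives a learning signal.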
23:54
Okay, and finally we have the attention operation, so let me paste in my attention.
24:00
So I know this is a lot, so I'm going to go through this a bit quickly, a bit slowly, but not too slowly, because we have covered this in the previous video and I would just point you there.
24:10
So this is the attention operation. Now, in the previous video, you will remember, this is not just attention, this is multi-headed attention, right? In the previous video we had this multi-headed attention module, and this implementation made it obvious that these heads are not actually that complicated. Basically, in parallel inside every attention block there are multiple heads, they're all functioning in parallel, and their outputs are just being concatenated, and that becomes the output of the multi-headed attention.
24:42
So the heads are just kind of like parallel streams, and their outputs get concatenated. It was very simple, and it made the head fairly straightforward in terms of its implementation.
24:56
What happens here is that instead of having two separate modules, and indeed many more modules that get concatenated, all of that is just put into a single self-attention module. Instead, I'm being very careful and doing a bunch of transpose and split tensor gymnastics to make this very efficient in PyTorch, but fundamentally and algorithmically nothing is different from the implementation we saw before in this repository.
25:25
So to remind you very briefly, and I don't want to spend too much time on this: we have these tokens lined up in a sequence, and there are 1,024 of them, and each token at this stage of the attention emits three vectors: the query, the key, and the value.
25:44
First, what happens here is that the queries and the keys have to multiply each other to get sort of the attention amount, like how interesting they find each other, so they have to interact multiplicatively.
25:56
So what we're doing here is we're calculating the qkv, we're splitting it, and then there's a bunch of gymnastics, as I mentioned. The way this works is that we're basically making the number of heads, nh, into a batch dimension. It's a batch dimension just like B, so that in these operations that follow, PyTorch treats B and nh as batches, and it applies all the operations on all of them in parallel, in both the batch and the heads.
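The "heads as a batch dimension" gymnastics can be mimicked with nested lists for a single sequence. This is a plain-Python sketch, not the PyTorch code from the video (which, as I understand it, does the equivalent with view and transpose on tensors); split_heads and merge_heads are hypothetical helper names.

```python
def split_heads(x, n_head):
    """(T, C) nested list -> (n_head, T, head_size): heads become a leading batch axis."""
    T, C = len(x), len(x[0])
    hs = C // n_head  # head size
    return [[x[t][h * hs:(h + 1) * hs] for t in range(T)] for h in range(n_head)]

def merge_heads(heads):
    """(n_head, T, head_size) -> (T, C): concatenate each token's per-head slices."""
    n_head, T = len(heads), len(heads[0])
    return [[v for h in range(n_head) for v in heads[h][t]] for t in range(T)]

# T=4 tokens with C=6 channels, split into 3 heads of size 2.
x = [[float(10 * t + c) for c in range(6)] for t in range(4)]
heads = split_heads(x, n_head=3)
print(len(heads), len(heads[0]), len(heads[0][0]))  # n_head, T, head_size
roundtrip = merge_heads(heads)
```

Once the head axis sits next to the batch axis, every later operation can treat each (batch, head) pair as an independent problem, and merging the heads back is just concatenation along the channel axis.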
26:25
The operations that get applied are, number one, the queries and the keys interact to give us the attention. This is the autoregressive mask that makes sure that the tokens only attend to tokens before them and never to tokens in the future. The softmax here normalizes the attention so it always sums to one.
26:45
And then, recall from the previous video that doing the attention matrix multiply with the values is basically a way to do a weighted sum of the values of the tokens that we found interesting, at every single token.
26:57
The final transpose, contiguous, and view is just reassembling all of that again, and this actually performs the concatenation operation. You can step through this slowly if you'd like, but it is mathematically equivalent to our previous implementation; it's just more efficient in PyTorch, so that's why I chose this implementation instead.
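The sequence of operations described above (query-key interaction, causal mask, softmax, weighted sum of values) can be written out for a single head in plain Python. This is a minimal sketch for illustration, not the batched PyTorch code from the video; it assumes the standard 1/sqrt(head_size) scaling, and causal_attention is a made-up name.

```python
import math

def causal_attention(q, k, v):
    """Single-head causal self-attention; q, k, v are (T, hs) nested lists."""
    T, hs = len(q), len(q[0])
    out, weights = [], []
    for t in range(T):
        # Scaled dot product of query t with keys 0..t only (the causal mask).
        scores = [sum(a * b for a, b in zip(q[t], k[s])) / math.sqrt(hs)
                  for s in range(t + 1)]
        # Softmax over the visible positions; subtract the max for stability.
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]
        z = sum(exps)
        w = [e / z for e in exps] + [0.0] * (T - t - 1)  # future gets weight 0
        weights.append(w)
        # Weighted sum of the value vectors.
        out.append([sum(w[s] * v[s][d] for s in range(T)) for d in range(len(v[0]))])
    return out, weights

q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
v = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
out, weights = causal_attention(q, k, v)
for row in weights:
    print([round(w, 3) for w in row])
```

Each printed row sums to one and is zero to the right of the diagonal: a token weighs itself and earlier tokens, never later ones.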
27:19
being careful with how I name my
27:19
being careful with how I name my variables so for example cattin is the
27:22
variables so for example cattin is the
27:22
variables so for example cattin is the same as seaten and so actually our keys
27:25
same as seaten and so actually our keys
27:25
same as seaten and so actually our keys should basically exactly follow the
27:26
should basically exactly follow the
27:27
should basically exactly follow the schema of the hugging face train
27:28
schema of the hugging face train
27:28
schema of the hugging face train Transformers code and that will make it
27:29
Transformers code and that will make it
27:29
Transformers code and that will make it very easy for us to now Port over all
27:32
very easy for us to now Port over all
27:32
very easy for us to now Port over all the weights from exactly this sort of
27:34
the weights from exactly this sort of
27:34
the weights from exactly this sort of naming conventions because all of our
27:36
naming conventions because all of our
27:36
naming conventions because all of our variables are named the same thing but
27:39
variables are named the same thing but
27:39
variables are named the same thing but um at this point we have finished the
27:41
um at this point we have finished the
27:41
um at this point we have finished the gpt2 implementation and what that allows
27:44
gpt2 implementation and what that allows
27:44
gpt2 implementation and what that allows us to do is we don't have to basically
27:46
us to do is we don't have to basically
27:46
us to do is we don't have to basically use uh this file from hugging face which
27:48
use uh this file from hugging face which
27:48
use uh this file from hugging face which is fairly long
27:49
is fairly long
27:50
is fairly long um this
27:52
um this
27:52
um this is uh 2,000 lines of code um instead we
27:57
is uh 2,000 lines of code um instead we
27:57
is uh 2,000 lines of code um instead we just have a less than 100 lines of code
27:59
just have a less than 100 lines of code
27:59
just have a less than 100 lines of code and this is the complete uh gpd2
28:01
and this is the complete uh gpd2
28:01
and this is the complete uh gpd2 implementation so at this stage we
28:02
implementation so at this stage we
28:02
implementation so at this stage we should just be able to take over all the
28:04
should just be able to take over all the
28:04
should just be able to take over all the weights set them and then do generation
28:07
weights set them and then do generation
28:07
weights set them and then do generation so let's see what that looks like okay
28:09
so let's see what that looks like okay
28:09
Okay, so here I've also changed the GPT config so that the numbers here, the hyperparameters, agree with the GPT-2 124M model. The maximum sequence length, which I call block size here, is 1,024. The number of tokens is 50,257, which, if you watched my tokenizer video, you know is 50,000 BPE merges, 256 byte tokens (the leaves of the BPE tree), and one special end-of-text token that delimits different documents and can start generation as well.
28:41
There are 12 layers, there are 12 heads in the attention, and the dimension of the Transformer is 768.
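The hyperparameters just listed can be collected in a small config object. This is a sketch; the field names block_size, vocab_size, n_layer, n_head and n_embd are what I would expect from the description, but treat them as assumptions rather than a copy of the file on screen.

```python
from dataclasses import dataclass

@dataclass
class GPTConfig:
    block_size: int = 1024   # maximum sequence length
    vocab_size: int = 50257  # 50,000 BPE merges + 256 byte tokens + 1 <|endoftext|>
    n_layer: int = 12        # number of Transformer blocks
    n_head: int = 12         # attention heads per block
    n_embd: int = 768        # embedding (residual stream) dimension

config = GPTConfig()
vocab_check = 50_000 + 256 + 1           # the breakdown given above
head_size = config.n_embd // config.n_head
print(vocab_check == config.vocab_size)  # True
print(head_size)                         # 64
```

The last line shows a consequence that is easy to miss: with 768 channels split over 12 heads, each head works with 64-dimensional queries, keys and values.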
28:49
So here's how we can now load the parameters from Hugging Face into our code here and initialize the GPT class with those parameters. Let me just copy-paste a bunch of code here.
29:00
I'm not going to go through this code too slowly because, honestly, it's not that interesting, it's not that exciting. We're just loading the weights, so it's kind of dry. But as I mentioned, there are four models in this miniseries of GPT-2. This is some of the Jupyter code that we had here on the right; I'm just porting it over. These are the hyperparameters of the GPT-2 models. We're creating the config object and creating our own model.
29:28
Then what's happening here is we're creating the state dict both for our model and for the Hugging Face model, and then we're going over the Hugging Face model keys and copying over those tensors. In the process we are kind of ignoring a few of the buffers. They're not parameters, they're buffers. So, for example, .attn.bias is just used for the autoregressive mask, and so we are ignoring some of those masks, and that's it.
29:58
Then one additional kind of annoyance is that this comes from the TensorFlow repo, and I'm not sure how, this is a little bit annoying, but some of the weights are transposed from what PyTorch would want. So manually I hardcoded the weights that should be transposed, and then we transpose them if that is so. And then we return this model.
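Why a transpose is needed can be shown with tiny nested-list matrices. This is a sketch under an assumption: as far as I know, the checkpoint stores these projection weights in an (in_features, out_features) layout, while nn.Linear keeps (out_features, in_features). The helper names below are made up, and no real checkpoint is loaded.

```python
def apply_in_out(x, w):
    """Checkpoint-style layout: w is (in_features, out_features), y = x @ w."""
    return [sum(x[i] * w[i][j] for i in range(len(x))) for j in range(len(w[0]))]

def apply_out_in(x, w):
    """nn.Linear-style layout: w is (out_features, in_features), y = x @ w^T."""
    return [sum(x[i] * row[i] for i in range(len(x))) for row in w]

def transpose(w):
    return [list(col) for col in zip(*w)]

w_checkpoint = [[1.0, 2.0, 3.0],
                [4.0, 5.0, 6.0]]    # (in=2, out=3)
x = [1.0, -1.0]

w_linear = transpose(w_checkpoint)  # (out=3, in=2), the layout nn.Linear expects
y_checkpoint = apply_in_out(x, w_checkpoint)
y_linear = apply_out_in(x, w_linear)
print(y_checkpoint, y_linear)
```

Both calls give the same output, which is the property the hardcoded transpose list is protecting: copy the tensor as-is and the shapes (or, for square matrices, the results) would silently be wrong.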
30:18
return this model so the from pre-trained is a
30:20
pre-trained is a
30:20
pre-trained is a Constructor or class method in Python
30:23
Constructor or class method in Python
30:23
Constructor or class method in Python that Returns the GPT object if we just
30:26
that Returns the GPT object if we just
30:26
that Returns the GPT object if we just give it the model type which in our case
30:28
give it the model type which in our case
30:28
give it the model type which in our case is gpt2 the smallest model that we're
30:30
is gpt2 the smallest model that we're
30:30
is gpt2 the smallest model that we're interested in so this is the code and
30:33
interested in so this is the code and
30:33
interested in so this is the code and this is how you would use it and um we
30:35
this is how you would use it and um we
30:35
We can pop open the terminal here in VS Code and run python train_gpt2.py, and fingers crossed.
30:46
Okay, so we didn't crash, and so we can load the weights and the biases and everything else into our nn.Module. But now let's also get additional confidence that this is working, and let's try to actually generate from this model.
31:00
Okay, now before we can actually generate from this model, we have to be able to forward it. We didn't actually write that code yet, so here's the forward function.
31:11
So the input to forward is going to be our indices, our token indices, and they are always of shape (B, T). So we have the batch dimension of B, and then we have the time dimension of up to T, and T can't be more than the block size; the block size is the maximum sequence length. So the (B, T) indices are arranged in sort of a two-dimensional layout, and remember that basically every single row of this is of size up to block size. These are T tokens that are in a sequence, and then we have B independent sequences stacked up in a batch so that this is efficient.
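The shape constraint on the input can be made concrete with a small check. This is a sketch under assumptions: the `check_input` helper is hypothetical, and the values 1024 (context length) and 50257 (vocabulary size) are the usual GPT-2 numbers, which this passage does not state.

```python
import torch

block_size = 1024  # assumed maximum sequence length (the GPT-2 context size)

def check_input(idx, block_size):
    """Return (B, T) for a batch of token indices, enforcing T <= block_size."""
    B, T = idx.size()
    assert T <= block_size, (
        f"cannot forward sequence of length {T}, block size is only {block_size}"
    )
    return B, T

idx = torch.randint(0, 50257, (5, 8))  # 5 independent sequences of 8 token ids
B, T = check_input(idx, block_size)
```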
31:51
Now here we are forwarding the position embeddings and the token embeddings, and this code should be very recognizable from the previous lecture. We basically use torch.arange, which is kind of like a version of range but for PyTorch, and we're iterating from 0 to T and creating these position indices. Then we are making sure that they're on the same device as idx, because we're not going to be training only on CPU, that's going to be too inefficient; we want to be training on GPU, and that's going to come in a bit.
32:20
Then we have the position embeddings and the token embeddings and the addition of those two. Now notice that the position embeddings are going to be identical for every single row of input, and so there's broadcasting hidden inside this plus, where we have to create an additional dimension here and then these two add up, because the same position embeddings apply at every single row of our examples stacked up in a batch.
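The broadcasting hidden inside that plus can be seen in a small standalone sketch. The sizes below are toy values chosen for illustration, and the names `wte` and `wpe` are assumed labels for the token and position embedding tables.

```python
import torch
import torch.nn as nn

B, T, n_embd = 5, 8, 16            # toy sizes, assumed for illustration
vocab_size, block_size = 100, 32

wte = nn.Embedding(vocab_size, n_embd)   # token embedding table
wpe = nn.Embedding(block_size, n_embd)   # position embedding table

idx = torch.randint(0, vocab_size, (B, T))
pos = torch.arange(0, T, dtype=torch.long, device=idx.device)  # shape (T,)
pos_emb = wpe(pos)        # (T, n_embd)
tok_emb = wte(idx)        # (B, T, n_embd)
x = tok_emb + pos_emb     # pos_emb is broadcast across the batch dimension
```

Adding a (T, n_embd) tensor to a (B, T, n_embd) tensor treats the smaller one as if it had a leading dimension of size 1, so every row in the batch receives the same position embeddings.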
32:46
Then we forward the Transformer blocks, and finally the last layer norm and the LM head. So what comes out after forward is the logits, and if the input was (B, T) indices, then at every single (B, T) position we will calculate the logits for what token comes next in the sequence, so what is the token at (B, T+1), the one on the right of this token. And vocab size here is the number of possible tokens, and so therefore this (B, T, vocab_size) tensor is what we're going to obtain, and these logits are just a softmax away from becoming probabilities. So this is the forward pass of the network, and now we can get logits, and so we're going to be able to generate from the model imminently.
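The tail end of the forward pass, and the step from logits to probabilities, can be sketched like this. The sizes are toy values, and the random tensor `x` is only a stand-in for the output of the Transformer blocks; `ln_f` and `lm_head` are assumed names for the final layer norm and the output projection.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

B, T, n_embd, vocab_size = 2, 4, 16, 10   # toy sizes, assumed for illustration

x = torch.randn(B, T, n_embd)             # stand-in for the output of the Transformer blocks
ln_f = nn.LayerNorm(n_embd)               # final layer norm
lm_head = nn.Linear(n_embd, vocab_size, bias=False)

logits = lm_head(ln_f(x))                 # (B, T, vocab_size)
probs = F.softmax(logits, dim=-1)         # normalize over the vocabulary axis
row_sums = probs.sum(dim=-1)              # (B, T), each entry close to 1
```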
33:30
Okay, so now we're going to try to set up the identical thing on the left here that matches Hugging Face on the right. So here we've sampled from the pipeline, and we sampled five times, up to 30 tokens, with the prefix of "Hello, I'm a language model," and these are the completions that we achieved. So we're going to try to replicate that on the left here: num_return_sequences is 5, max_length is 30.
33:53
So the first thing we do, of course, is we initialize our model, then we put it into evaluation mode. Now, this is a good practice, to put the model into eval when you're not going to be training it, you're just going to be using it. I don't actually know if this is doing anything right now, for the following reason: our model up above here contains no modules or layers that actually have a different behavior at training or evaluation time. So for example dropout, batch norm, and a bunch of other layers have this kind of behavior, but all of these layers that we've used here should be identical in both training and evaluation time. So potentially model.eval() does nothing, but then I'm not actually sure if this is the case, and maybe PyTorch internals do some clever things depending on the evaluation mode inside here.
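The difference between layers that care about evaluation mode and layers that do not can be checked directly. This sketch uses standalone `nn.Dropout` and `nn.LayerNorm` modules, not the model from the video: dropout changes behavior between the two modes, while layer norm gives the same output in both.

```python
import torch
import torch.nn as nn

drop = nn.Dropout(p=0.5)
x = torch.ones(1000)

drop.train()
y_train = drop(x)      # entries are zeroed at random, survivors scaled by 1/(1-p) = 2

drop.eval()
y_eval = drop(x)       # dropout is the identity in eval mode

ln = nn.LayerNorm(8)
z = torch.randn(4, 8)

ln.train()
out_train = ln(z)

ln.eval()
out_eval = ln(z)       # layer norm keeps no running statistics, so nothing changes
```

Calling `eval()` on a parent module sets the `training` flag on all of its submodules, so it is harmless to call even when no layer currently depends on it.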
34:39
The next thing we're doing here is we are moving the entire model to CUDA, so we're moving all of the tensors to the GPU. I'm SSH'd here into a cloud box, and I have a bunch of GPUs on this box, and here I'm moving the entire model and all of its members and all of its tensors and everything like that. Everything gets shipped off to basically a whole separate computer that is sitting on the GPU. The GPU is connected to the CPU and they can communicate, but it's basically a whole separate computer with its own computer architecture, and it's really well catered to parallel processing tasks like those of running neural networks. So I'm doing this so that the model lives on the GPU, a whole separate computer, and it's just going to make our code a lot more efficient, because all of this stuff runs a lot more efficiently on the GPUs. So that's the model itself.
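Moving a module and its inputs to the GPU can be sketched as below. The CPU fallback is an addition for illustration so the snippet also runs on a machine without a GPU, and the `nn.Linear` layer is only a stand-in for the GPT module.

```python
import torch
import torch.nn as nn

# Pick the GPU when one is visible, otherwise stay on the CPU (fallback assumed here).
device = "cuda" if torch.cuda.is_available() else "cpu"

model = nn.Linear(8, 8)     # stand-in for the GPT module
model.eval()
model.to(device)            # for nn.Module, parameters and buffers are moved in place

x = torch.zeros(5, 8).to(device)   # for tensors, .to returns the moved tensor, so keep the result
y = model(x)                       # inputs and parameters must live on the same device
```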
35:31
Now, the next thing we want to do is we want to start with this as the prefix when we do the generation, so let's actually create those prefix tokens. So here's the code that I've written: we're going to import the tiktoken library from OpenAI, and we're going to get the GPT-2 encoding, so that's the tokenizer for GPT-2, and then we're going to encode this string and get a list of integers, which are the tokens. Now, these integers here should actually be fairly straightforward, because we can just copy-paste this string and inspect what it is in Tiktokenizer. So just pasting that in, these are the tokens that are going to come out, so this list of integers is what we expect tokens to become. And as you recall, if you saw my video, of course, all the tokens are just little string chunks, right, so this is the chunking of this string into GPT-2 tokens.
36:25
So once we have those tokens, it's a list of integers, and we can create a torch tensor out of it. In this case it's eight tokens, and then we're going to replicate these eight tokens five times to get five rows of eight tokens, and that is our initial input x, as I call it here, and it lives on the GPU as well. So x now is this idx that we can put into forward to get our logits, so that we know what comes as the ninth token in every one of these five rows.
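Building the (5, 8) prefix batch from a list of token ids can be sketched as follows. The integer ids here are hypothetical placeholders, not the real GPT-2 tokens for the prompt; the commented lines show how the real ids would likely be obtained, assuming the tiktoken package is installed.

```python
import torch

num_return_sequences = 5

# Placeholder ids standing in for the tokenizer output. With tiktoken installed,
# the real ids would come from something like:
#   enc = tiktoken.get_encoding("gpt2")
#   tokens = enc.encode("Hello, I'm a language model,")
tokens = [10, 11, 12, 13, 14, 15, 16, 17]   # eight hypothetical token ids

tokens = torch.tensor(tokens, dtype=torch.long)            # shape (8,)
x = tokens.unsqueeze(0).repeat(num_return_sequences, 1)    # shape (5, 8)
```

Each row of `x` starts out identical; the rows diverge only once sampling appends different tokens to them.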
37:04
Okay, and we are now ready to generate, so let me paste in one more code block here. What's happening in this code block is we have this x, which is of size (B, T), right, so batch by time, and in every iteration of this loop we're going to be adding a column of new indices into each one of these rows. So these are the new indices, and we're appending them to the sequence as we're sampling, so with each loop iteration we get one more column in x. All of the operations happen in the context manager of torch.no_grad; this is just telling PyTorch that we're not going to be calling .backward() on any of this, so it doesn't have to cache all the intermediate tensors, and it's not going to have to prepare in any way for a potential backward later. This saves a lot of space and also possibly some time.
37:52
So we get our logits, and we take the logits at only the last location; we throw away all the other logits, we don't need them, we only care about the last column's logits. So this is being wasteful, it's just kind of an inefficient implementation of sampling, so it's correct but inefficient.
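The effect of torch.no_grad and the slicing of the last position can be seen with a toy model. The `toy_model` below is a hypothetical stand-in (an embedding followed by a linear layer), not the GPT module; it only needs to map (B, T) token ids to (B, T, vocab_size) logits.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
vocab_size, n_embd = 50, 16

# Toy stand-in for the language model: (B, T) token ids -> (B, T, vocab_size) logits.
toy_model = nn.Sequential(nn.Embedding(vocab_size, n_embd),
                          nn.Linear(n_embd, vocab_size))

x = torch.randint(0, vocab_size, (5, 8))   # (B, T)

with torch.no_grad():
    logits = toy_model(x)          # (B, T, vocab_size), no autograd graph is recorded
    last = logits[:, -1, :]        # (B, vocab_size): only the final position is needed

tracked = toy_model(x)             # same call outside the context manager
```

Inside the context manager the output carries no `grad_fn`, while the same call outside it does, which is the bookkeeping that gets skipped during sampling.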
38:13
So we get the last column of logits and pass it through softmax to get our probabilities. Then here I'm doing top-k sampling of 50, and I'm doing that because this is the Hugging Face default. So just looking at the Hugging Face docs here for a pipeline, there's a bunch of kwargs that go into Hugging Face, and I mean, it's kind of a lot, honestly, but I guess the important one that I noticed is that they're using top-k by default, which is 50, so that's being used here as well. What that does is basically we take our probabilities and we only keep the top 50 probabilities, and anything that is lower than the 50th probability we just clamp to zero and renormalize. That way we are never sampling very rare tokens; the tokens we're going to be sampling are always in the top 50 most likely tokens. This helps keep the model kind of on track, so it doesn't blabber on, it doesn't get lost, and it doesn't go off the rails as easily, and it sticks in the vicinity of likely tokens a lot better.
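One way to write the top-k step in PyTorch is sketched below. The random `logits` tensor stands in for the last-position logits of a real model, and the sizes are toy values; only the value k = 50 comes from the discussion above.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(42)
B, vocab_size, k = 5, 1000, 50

logits = torch.randn(B, vocab_size)            # stand-in for the last-position logits
probs = F.softmax(logits, dim=-1)

# Keep the k largest probabilities per row; topk_indices holds their vocabulary ids.
topk_probs, topk_indices = torch.topk(probs, k, dim=-1)      # both (B, k)

# multinomial treats each row as non-negative weights, so the renormalization
# over the surviving k entries happens implicitly.
ix = torch.multinomial(topk_probs, num_samples=1)            # (B, 1), positions within the top k
next_token = torch.gather(topk_indices, -1, ix)              # (B, 1), actual vocabulary ids
```

The sampled position `ix` indexes into the top-k list, not into the vocabulary, which is why the final `gather` is needed to map it back to a token id.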
39:17
So this is the way to do it in PyTorch, and you can step through it if you like. I don't think it's super insightful, so I'll speed through it, but roughly speaking we get this new column of tokens, we append them to x, and basically the columns of x grow until this while loop gets tripped, and then finally we have an entire x of size 5 by 30 in this example. And we can just basically print all those individual rows, so I'm getting all the rows, I'm getting all the tokens that were sampled, and I'm using the decode function from the tiktoken tokenizer to get back the string, which we can print. And so: Terminal, New Terminal, and let me run python train_gpt2.py.
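Putting the pieces together, the whole sampling loop might look roughly like this. It is a sketch under assumptions: `model` is a toy stand-in for the GPT module, the prefix ids are random rather than real tokens, and the final loop prints raw ids where the real script would decode them with the tokenizer.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(42)
vocab_size, n_embd = 200, 16
num_return_sequences, max_length = 5, 30

# Toy stand-in for the GPT module: (B, T) token ids -> (B, T, vocab_size) logits.
model = nn.Sequential(nn.Embedding(vocab_size, n_embd),
                      nn.Linear(n_embd, vocab_size))
model.eval()

# Random 8-token prefix, replicated into 5 identical rows: shape (5, 8).
x = torch.randint(0, vocab_size, (1, 8)).repeat(num_return_sequences, 1)

while x.size(1) < max_length:
    with torch.no_grad():
        logits = model(x)[:, -1, :]                          # (B, vocab_size)
        probs = F.softmax(logits, dim=-1)
        topk_probs, topk_indices = torch.topk(probs, 50, dim=-1)
        ix = torch.multinomial(topk_probs, 1)                # (B, 1)
        xcol = torch.gather(topk_indices, -1, ix)            # (B, 1)
        x = torch.cat((x, xcol), dim=1)                      # grow by one column

for i in range(num_return_sequences):
    # With a real tokenizer, decode these ids back into a string instead.
    print(">", x[i, :max_length].tolist())
```

Every iteration re-runs the model on the full sequence and discards all but the last position, which is the inefficiency mentioned earlier; it keeps the loop short at the cost of repeated work.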
40:08
gpt2 okay so these are the generations
40:11
gpt2 okay so these are the generations
40:11
gpt2 okay so these are the generations that we're getting hello I'm a language
40:13
that we're getting hello I'm a language
40:13
that we're getting hello I'm a language model not a
40:14
model not a
40:15
model not a program um new line new line Etc hello
40:19
program um new line new line Etc hello
40:19
program um new line new line Etc hello I'm a language model and one of the main
40:20
I'm a language model and one of the main
40:21
I'm a language model and one of the main things that bothers me when they create
40:22
things that bothers me when they create
40:22
things that bothers me when they create languages is how easy it becomes to
40:23
languages is how easy it becomes to
40:23
languages is how easy it becomes to create something that I me so this will
40:26
create something that I me so this will
40:26
create something that I me so this will just like blabber on right in all these
40:27
just like blabber on right in all these
40:27
just like blabber on right in all these cases now one thing you will notice is
40:29
cases now one thing you will notice is
40:29
cases now one thing you will notice is that these Generations are not the
40:31
that these Generations are not the
40:31
that these Generations are not the generations of hugging face here and I
40:35
generations of hugging face here and I
40:35
generations of hugging face here and I can't find the discrepancy to be honest
40:37
can't find the discrepancy to be honest
40:37
can't find the discrepancy to be honest and I didn't fully go through all these
40:39
and I didn't fully go through all these
40:39
and I didn't fully go through all these options but probably there's something
40:40
options but probably there's something
40:40
options but probably there's something else hiding in on addition to the top P
40:43
else hiding in on addition to the top P
40:43
else hiding in on addition to the top P so I'm not able to match it up but just
40:44
so I'm not able to match it up but just
40:45
so I'm not able to match it up but just for correctness um down here Below in
40:47
for correctness um down here Below in
40:47
for correctness um down here Below in the juper notebook and using the hugging
40:49
the juper notebook and using the hugging
40:49
the juper notebook and using the hugging face model so this is the hugging face
40:52
face model so this is the hugging face
40:52
face model so this is the hugging face model here I was I replicated the code
40:56
model here I was I replicated the code
40:56
model here I was I replicated the code and if I do this and I run that then I
40:59
and if I do this and I run that then I
40:59
and if I do this and I run that then I am getting the same results so basically
41:03
am getting the same results so basically
41:03
am getting the same results so basically the model internals are not wrong it's
41:05
the model internals are not wrong it's
41:05
the model internals are not wrong it's just I'm not 100% sure what the pipeline
41:08
just I'm not 100% sure what the pipeline
41:08
just I'm not 100% sure what the pipeline does in hugging face and that's why
41:09
does in hugging face and that's why
41:09
does in hugging face and that's why we're not able to match them up but
41:11
we're not able to match them up but
41:11
we're not able to match them up but otherwise the code is correct and we've
41:13
otherwise the code is correct and we've
41:13
otherwise the code is correct and we've loaded all the um tensors correctly so
41:16
loaded all the um tensors correctly so
41:16
loaded all the um tensors correctly so we're initializing the model correctly
41:17
we're initializing the model correctly
41:18
we're initializing the model correctly and everything here works so long story
41:20
and everything here works so long story
41:20
and everything here works so long story short uh We've Port it all the weights
41:22
short uh We've Port it all the weights
41:22
short uh We've Port it all the weights we initialize the gpt2 this is the exact
41:25
we initialize the gpt2 this is the exact
41:25
we initialize the gpt2 this is the exact opening gpt2 and it can generate
41:27
opening gpt2 and it can generate
41:27
opening gpt2 and it can generate sequences and they look sensible and now
41:30
And now here, of course, we're initializing with GPT-2 model weights, but now we want to initialize from scratch, from random numbers, and we want to actually train a model that will give us sequences as good as or better than these ones in quality. And so that's what we turn to next.
41:48
So it turns out that using the random model is actually fairly straightforward, because PyTorch already initializes our model randomly by default. So when we create the GPT model in the constructor, all of these layers and modules have random initializers that are there by default. So when these linear layers get created and so on, there are default constructors, for example using the Xavier initialization that we saw in the past, to construct the weights of these layers.
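A quick way to see this default random initialization is sketched below. It assumes only that PyTorch is installed; the layer sizes are illustrative and not taken from the video. In current PyTorch releases, nn.Linear draws its weights uniformly from plus or minus 1/sqrt(fan_in) (a Kaiming-uniform variant rather than strict Xavier), but the point is the same: a freshly constructed layer already holds random weights.

```python
import math

import torch
import torch.nn as nn

torch.manual_seed(0)  # only so that this demo is repeatable

# Illustrative sizes (assumption): not read from the video's config.
fan_in, fan_out = 768, 3 * 768

# Two layers built the same way get different random weights.
layer_a = nn.Linear(fan_in, fan_out)
layer_b = nn.Linear(fan_in, fan_out)

# Default nn.Linear weights are sampled uniformly within +/- 1/sqrt(fan_in).
bound = 1.0 / math.sqrt(fan_in)
max_abs_weight = layer_a.weight.abs().max().item()

print("largest |weight|:", max_abs_weight, "bound:", bound)
print("layers identical:", torch.equal(layer_a.weight, layer_b.weight))
```

No explicit init call is needed to get a random model; constructing the modules is enough.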
42:18
And so creating a random model instead of a GPT-2 model is actually fairly straightforward. We would just come here and instead we would create model = GPT(GPTConfig()), using the default config, and the default config uses the 124M parameters. So this is the random model initialization, and we can run it and we should be able to get results. Now the results here, of course, are total garbage garble, and that's because this is a random model, and so we're just getting all these random token string pieces chunked up totally at random. So that's what we have right now.
42:59
Now one more thing I wanted to point out, by the way, is that in case you do not have CUDA available because you don't have a GPU, you can still follow along with what we're doing here to some extent, though probably not to the very end, because by the end we're going to be using multiple GPUs and actually doing a serious training run. But for now you can actually follow along decently.
43:19
Okay, so one thing that I like to do in PyTorch is autodetect the device that is available to you. So in particular you could do that like this. So here we are trying to detect a device to run on that has the highest compute capability; you can think about it that way. So by default we start with CPU, which of course is available everywhere because every single computer will have a CPU. But then we can try to detect: do you have a GPU? If so, use CUDA. And then if you don't have CUDA, do you at least have MPS? MPS is the backend for Apple silicon, so if you have a MacBook that is fairly new, you probably have Apple silicon on the inside, and that has a GPU that is actually fairly capable, depending on which MacBook you have. And so you can use MPS, which will be potentially faster than CPU. And so we can print the device here.
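The detection order described here (CPU as the fallback, then CUDA, then MPS) can be sketched as follows. The helper name detect_device is an assumption made for this example; the hasattr guard is there only so the sketch also runs on older PyTorch builds that lack the MPS backend.

```python
import torch


def detect_device() -> str:
    """Pick the most capable available backend: cuda, then mps, else cpu."""
    device = "cpu"  # always available
    if torch.cuda.is_available():
        device = "cuda"
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        device = "mps"  # Apple silicon GPU backend
    return device


device = detect_device()
print(f"using device: {device}")

# To force CPU for a quick comparison, override the result by hand:
# device = "cpu"
```

The returned string can then be used anywhere a hard-coded "cuda" appeared before, for example model.to(device).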
44:07
Now once we have the device, we can actually use it in place of cuda, so we just swap it in. And notice that here, when we call model on x, if this x is on CPU instead of GPU, then it will work fine, because here in the forward, which is where PyTorch will come, when we create pos we were careful to use the device of idx to create this tensor as well. And so there won't be any mismatch where one tensor is on CPU and one is on GPU and you can't combine those. Here we are carefully initializing on the correct device, as indicated by the input to this model.
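The device-matching point can be illustrated with a much smaller module than the full GPT. TinyEmbedder below is a hypothetical stand-in written for this sketch (its name and sizes are assumptions); the one line that matters is the torch.arange call, which creates the position indices on idx.device so they always live where the input lives.

```python
import torch
import torch.nn as nn


class TinyEmbedder(nn.Module):
    """Hypothetical stand-in for the embedding step of a GPT forward pass."""

    def __init__(self, vocab_size: int = 100, block_size: int = 16, n_embd: int = 8):
        super().__init__()
        self.wte = nn.Embedding(vocab_size, n_embd)  # token embeddings
        self.wpe = nn.Embedding(block_size, n_embd)  # position embeddings

    def forward(self, idx: torch.Tensor) -> torch.Tensor:
        B, T = idx.size()
        # Create positions on the same device as the input, never hard-coded.
        pos = torch.arange(0, T, dtype=torch.long, device=idx.device)
        # (B, T, C) + (T, C): the position embeddings broadcast over the batch.
        return self.wte(idx) + self.wpe(pos)


device = "cuda" if torch.cuda.is_available() else "cpu"
model = TinyEmbedder().to(device)
idx = torch.randint(0, 100, (4, 6), device=device)
out = model(idx)
print(tuple(out.shape), out.device)
```

If pos were created without the device argument, it would default to CPU and the addition would fail whenever the model and input sit on a GPU.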
44:47
So this will autodetect the device. For me this will of course be GPU, so: using device cuda. But you can also run with, as I mentioned, another device, and it's not going to be too much slower. So if I override device here, if I override device = "cpu", then we'll still print cuda, of course, but now we're actually using CPU. One, two, three, four, five, six. Okay, about 6 seconds. And actually we're not using torch.compile and stuff like that, which will speed everything up a lot as well. But you can follow along even on a CPU, I think, to a decent extent. So that's a note on that.
45:32
Okay, so I do want to loop around eventually to what it means to have different devices in PyTorch, and what exactly PyTorch does in the background for you when you do something like module.to(device), or where you take a torch tensor and do .to(device), what exactly happens and how that works. But for now I'd like to get to training and start training the model, so for now let's just say the device makes code go fast. And let's go into how we can actually train the model.
46:00
So to train the model we're going to need some dataset, and for me the best, simplest debugging dataset that I like to use is the tiny Shakespeare dataset. It's available at this URL, so you can wget it, or you can just search for the tiny Shakespeare dataset. I have it in my file system (ls) as just input.txt, so I already downloaded it. And here I'm reading the dataset, getting the first 1,000 characters and printing the first 100. Now remember that the GPT-2 tokenizer has a compression ratio of roughly 3 to 1, so 1,000 characters is roughly 300 tokens that will come out of the slice that we're currently getting. So this is the first few characters.
46:46
And if you want to get a few more statistics on this, we can do a word count (wc) on input.txt. So we can see that this is 40,000 lines, about 200,000 words in this dataset, and about 1 million bytes in this file. And knowing that this file is only ASCII characters, there's no crazy Unicode here as far as I know, every ASCII character is encoded with one byte, and so this is roughly the same number, a million characters, inside this dataset. So that's the dataset size: by default a very small and minimal dataset for debugging, to get us off the ground.
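The same statistics that wc reports, plus the "one byte per character" check, can be reproduced in a few lines of plain Python. This is a minimal sketch: the helper name wc_stats is an assumption, and the short sample string stands in for the contents of input.txt so that the example runs without the file.

```python
def wc_stats(text: str):
    """Return (lines, words, bytes), roughly matching what `wc` prints."""
    data = text.encode("utf-8")
    n_lines = data.count(b"\n")   # wc -l counts newline characters
    n_words = len(text.split())   # whitespace-separated words
    n_bytes = len(data)           # wc -c counts bytes
    return n_lines, n_words, n_bytes


# Stand-in sample; for the real file use: text = open("input.txt").read()
sample = "First Citizen:\nBefore we proceed any further, hear me speak.\n"

n_lines, n_words, n_bytes = wc_stats(sample)
print("lines:", n_lines, "words:", n_words, "bytes:", n_bytes)

# For pure-ASCII text, every character is one byte in UTF-8,
# so the byte count equals the character count.
print("ascii only:", sample.isascii())
print("bytes == characters:", n_bytes == len(sample))
```

On a file containing non-ASCII characters the last check would print False, because those characters take more than one byte each in UTF-8.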
47:21
In order to tokenize this dataset, we're going to get the tiktoken encoding for GPT-2, encode the data, the first 1,000 characters, and then I'm only going to print the first 24 tokens. So these are the tokens as a list of integers, and if you can read GPT-2 tokens, you will see that 198 here, you'll recognize that as the \n character, so that is a newline. And then here, for example, we have two newlines, so that's 198 twice. So this is just a tokenization of the first 24 tokens.
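The tokenization step and the rough 3-to-1 compression estimate can be sketched as below. The helper names gpt2_encode, byte_encode, and chars_per_token are assumptions made for this example. The tiktoken import is deferred into a function because it needs the package installed and may download vocabulary files on first use; the demo itself runs on a byte-level stand-in, which is not the GPT-2 tokenizer and always gives a ratio of 1.0 for ASCII text.

```python
from typing import Callable, List


def gpt2_encode() -> Callable[[str], List[int]]:
    """Return the GPT-2 BPE encode function (requires `pip install tiktoken`)."""
    import tiktoken  # imported lazily so the rest of the sketch runs without it

    enc = tiktoken.get_encoding("gpt2")
    return enc.encode


def byte_encode(text: str) -> List[int]:
    """Offline stand-in: one token per UTF-8 byte (NOT the GPT-2 tokenizer)."""
    return list(text.encode("utf-8"))


def chars_per_token(text: str, encode: Callable[[str], List[int]]) -> float:
    """Average number of characters covered by one token."""
    return len(text) / len(encode(text))


ratio = chars_per_token("hello world\n", byte_encode)
print("byte-level chars per token:", ratio)

# Back-of-the-envelope estimate at roughly 3 characters per token:
estimated_tokens = round(1000 / 3)
print("estimated tokens in 1,000 characters:", estimated_tokens)

# With the real tokenizer (assuming `text` holds the contents of input.txt):
# encode = gpt2_encode()
# tokens = encode(text[:1000])
# print(tokens[:24])
```

Passing gpt2_encode() instead of byte_encode to chars_per_token would measure the actual ratio on whatever text is supplied.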
47:54
So what we want to do now is actually process these token sequences and feed them into a Transformer. In particular, we want to rearrange these tokens into this idx variable that we're going to be feeding into the Transformer. So we don't want a single, very long one-dimensional sequence; we want an entire batch where each sequence is basically T tokens, and T cannot be larger than the maximum sequence length. And then we have these T-long sequences of tokens, and we have B independent examples of sequences. So how can we create a B by T tensor that we can feed into the forward out of these one-dimensional sequences?
48:36
So here's my favorite way to achieve this. If we take torch and create a tensor object out of this list of integers, just the first 24 tokens, my favorite way to do this is basically you do a .view() of, for example, 4 by 6, which multiply to 24, and so it's just a two-dimensional rearrangement of these tokens. And you'll see that when you view this one-dimensional sequence as two-dimensional 4 by 6, the first six tokens, up to here, end up being the first row, the next six tokens end up being the second row, and so on. So basically it's just going to stack up every six tokens, in this case, as independent rows, and it creates a batch of tokens. And so, for example, if we are token 25 in the Transformer, when we feed this in and this becomes the idx, this token is going to see these three tokens and it's going to try to predict that 198 comes next. So in this way we are able to create this two-dimensional batch, and that's quite nice.
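The row-by-row behavior of .view() is easy to check on a small tensor. In this sketch torch.arange(24) is a stand-in for the first 24 token ids, so the values are placeholders rather than real GPT-2 tokens; it makes it obvious which positions land in which row.

```python
import torch

# Placeholder "tokens" 0..23 (assumption: stand-ins for real token ids).
tokens = torch.arange(24)

# View the flat sequence as B=4 rows of T=6 tokens each.
B, T = 4, 6
batch = tokens.view(B, T)
print(batch)

# view() does not copy: the 2D tensor shares storage with the 1D one.
print("shares storage:", batch.data_ptr() == tokens.data_ptr())
```

Each row is a consecutive chunk of the original sequence, so the order of tokens within a row is preserved.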
49:44
Now in terms of the label that we're going to need for the target to calculate the loss function, how do we get that? Well, we could write some code inside the forward pass, because we know that the next token in a sequence, which is the label, is just to the right of us. But you'll notice that for this token at the very end, 13, we don't actually have the next correct token, because we didn't load it. So we actually didn't get enough information here. So I'll show you my favorite way of getting these batches. I like to personally have not just the input to the Transformer, which I like to call x, but I also like to create the labels tensor, which is of the exact same size as x but contains the targets at every single position.
50:28
And so here's the way that I like to do that. I like to make sure that I fetch plus one token, because we need the ground truth for the very last token, for 13. And then when we're creating the input, we take everything up to but not including the last token and view it as 4 by 6, and when we're creating targets, we take the buffer but starting at index one, not index zero, so we're skipping the first element, and we view it in the exact same size. And then when I print this, here's what happens: we see that, as an example, for this token 25 its target was 198, and that's now just stored at the exact same position in the target tensor, which is 198. And also this last token 13 now has its label, which is 198, and that's just because we loaded this plus one here. So basically this is the way I like to do it: you take long sequences, you view them in two-dimensional terms so that you get batch by time, and then we make sure to load one additional token. So we basically load a buffer of B * T + 1 tokens, and then we sort of offset things and view them, and then we have two tensors: one of them is the input to the Transformer and the other is exactly the labels.
51:46
labels and so let's now reorganize this
51:46
labels and so let's now reorganize this code and um create a very simple data
51:50
code and um create a very simple data
51:50
code and um create a very simple data loader object that tries to basically
51:52
loader object that tries to basically
51:52
loader object that tries to basically load these tokens and um feed them to
51:55
load these tokens and um feed them to
51:55
load these tokens and um feed them to the Transformer and calculate the loss
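The offset-by-one construction described above can be sketched as follows. This is a minimal illustration, assuming PyTorch is installed; the token ids are placeholder integers (0 to 24) rather than real tokenizer output, and the names `buf`, `x`, and `y` are chosen here for illustration.

```python
import torch

B, T = 4, 6  # batch size and sequence length, matching the 4x6 example

# Placeholder token ids standing in for real tokenizer output (assumption).
tokens = list(range(B * T + 1))

buf = torch.tensor(tokens[: B * T + 1])  # B*T + 1 tokens: one extra for the last target
x = buf[:-1].view(B, T)  # inputs: everything except the last token
y = buf[1:].view(B, T)   # targets: the same buffer shifted by one position

print(x)
print(y)
```

Each target `y[b, t]` is the token that follows `x[b, t]` in the original stream, including at the final position, which is why the one extra token is loaded.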
51:57
Okay, so I reshuffled the code here accordingly. As you can see, I'm temporarily overriding the device to run on the CPU, and importing tiktoken, and all of this should look familiar. We're loading 1,000 characters, and I'm setting B and T to just 4 and 32 right now, because we're debugging and we just want a single batch that's very small. All of this should now look familiar and follows what we did on the right.
52:21
Then here we create the model and get the logits. As you see, I already ran this, and it only runs in a few seconds. Because we have a batch of 4x32, our logits are now of size 4x32x50257, so those are the logits for what comes next at every position. And now we have the labels, which are stored in y. So now is the time to calculate the loss, then do the backward pass, and then the optimization. Let's first calculate the loss.
52:55
Okay, so to calculate the loss we're going to adjust the forward function of this nn.Module in the model. In particular, we're not just going to return the logits, we're also going to return the loss, and we're going to pass in not just the input indices but also the targets, in y. And now we won't print logits.shape anymore; we're going to print the loss, and then sys.exit(0) so that we skip some of the sampling logic.
53:21
So now let's swing up to the forward function, which gets called there, because now we also have these optional targets. When we get the targets, we can also calculate the loss. Remember that we want to return logits, loss, and loss by default is None. But let's put this here: if targets is not None, then we want to calculate the loss.
53:51
Copilot is already getting excited here and calculating what looks to be the correct loss. It is using the cross-entropy loss as documented here, so this is a function in PyTorch under torch.nn.functional.
54:05
Now, what is actually happening here? Because it looks a little bit scary. Basically, F.cross_entropy does not like multi-dimensional inputs; it can't take B by T by vocab size directly. So what's happening here is that we are flattening out this three-dimensional tensor into just two dimensions. The first dimension is going to be calculated automatically and it's going to be B*T, and the last dimension is vocab size. So this flattens the three-dimensional tensor of logits to be two-dimensional: B*T individual examples, with vocab size as the length of each row. It's also flattening out the targets, which are two-dimensional at this stage, so that they're just a single tensor of B*T. This can then be passed into cross entropy to calculate a loss, which we return. So this should basically run at this point, because this is not too complicated.
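The flattening step described above can be sketched like this. It is a minimal illustration, not the model's actual forward function: random tensors stand in for the model output and the labels, and the vocabulary is shrunk to 50 entries (the real one is 50257) so it runs quickly. All variable names here are assumptions.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, T, vocab_size = 4, 32, 50  # small stand-in vocab (the real one is 50257)

# Random tensors standing in for model output and labels (assumption).
logits = torch.randn(B, T, vocab_size)
targets = torch.randint(0, vocab_size, (B, T))

# Flatten (B, T, vocab_size) -> (B*T, vocab_size); -1 lets PyTorch infer B*T.
flat_logits = logits.view(-1, logits.size(-1))
# Flatten (B, T) -> (B*T,)
flat_targets = targets.view(-1)

loss = F.cross_entropy(flat_logits, flat_targets)
print(flat_logits.shape, flat_targets.shape, loss)
```

The result is a tensor holding a single number: the mean negative log likelihood over all B*T positions.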
54:59
So let's run it and see if we print the loss.
55:09
And here we see that we printed roughly 11. Notice that this is a tensor of a single element, which is this number 11.
55:23
Now, we also want to be able to calculate a reasonable starting point for a randomly initialized network. We covered this in previous videos, but our vocabulary size is 50257. At initialization of the network, you would hope that every vocab element is getting roughly uniform probability, so that we're not favoring any token way too much at initialization; we're not confidently wrong at initialization. So what we're hoping is that the probability of any arbitrary token is roughly 1/50257.
55:55
Now we can sanity check the loss, because remember that the cross-entropy loss is basically just the negative log likelihood. So if we take this probability, take it through the natural logarithm, and then negate it, that is the loss we expect at initialization. We covered this in previous videos. I would expect something around 10.82, and we're seeing something around 11, so it's not way off. This is roughly what I expect at initialization, so that tells me that at initialization our probability distribution is roughly diffuse. It's a good starting point.
56:30
We can now perform the optimization and tell the network which elements should follow correctly in what order. So at this point we can do a loss.backward(), calculate the gradients, and do an optimization. So let's get to that.
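The sanity check above is a one-line calculation. Here is a small sketch using only the standard library; `vocab_size` and `expected_loss` are names chosen here for illustration.

```python
import math

vocab_size = 50257

# With a uniform distribution, every token gets probability 1/vocab_size,
# so the cross-entropy (negative log likelihood) is -ln(1/vocab_size).
expected_loss = -math.log(1.0 / vocab_size)
print(f"{expected_loss:.2f}")
```

This comes out at about 10.82, so an observed initial loss near 11 suggests the untrained network is not far from uniform.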
56:43
Okay, so let's do the optimization now. Here we have how we get the loss, but now we want a little for loop here, so for i in range, let's do 50 steps or something like that. Let's create an optimizer object in PyTorch.
57:00
Here we are using the Adam optimizer, which is an alternative to the stochastic gradient descent optimizer, SGD, that we were using. SGD is a lot simpler; Adam is a bit more involved, and I actually specifically like the AdamW variation because, in my opinion, it kind of just fixes a bug. So AdamW is a bug fix of Adam, is what I would say.
57:25
When we go to the documentation for AdamW, we see that it takes a bunch of hyperparameters, and it's a little bit more complicated than the SGD we were looking at before. In addition to updating the parameters with the gradient scaled by the learning rate, it keeps these buffers around. It keeps two buffers, m and v, which it calls the first and the second moment: something that looks a bit like momentum and something that looks a bit like RMSProp, if you're familiar with it, but you don't have to be. It's just kind of a normalization that happens on each gradient element individually, and it speeds up the optimization, especially for language models.
58:04
But I'm not going to go into the detail right here. We're going to treat it as a bit of a black box, and it just optimizes the objective faster than SGD, which is what we've seen in the previous lectures. So let's use it as a black box in our case, create the optimizer object, and then go through the optimization.
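For readers who want a peek inside the black box, here is a rough sketch of one AdamW update for a single scalar parameter, written in plain Python. It follows the textbook form of the algorithm (two running buffers `m` and `v`, bias correction, decoupled weight decay) and is not PyTorch's actual implementation; the function name `adamw_step` and the toy objective are assumptions made for illustration.

```python
import math

def adamw_step(p, g, m, v, t, lr=3e-4, beta1=0.9, beta2=0.999,
               eps=1e-8, weight_decay=0.01):
    """One AdamW update for a scalar parameter p with gradient g at step t (t >= 1)."""
    p = p - lr * weight_decay * p            # decoupled weight decay
    m = beta1 * m + (1 - beta1) * g          # first moment: momentum-like average of gradients
    v = beta2 * v + (1 - beta2) * g * g      # second moment: RMSProp-like average of squared gradients
    m_hat = m / (1 - beta1 ** t)             # bias correction for the zero-initialized buffers
    v_hat = v / (1 - beta2 ** t)
    p = p - lr * m_hat / (math.sqrt(v_hat) + eps)  # per-element normalized step
    return p, m, v

# Toy objective f(p) = p**2, whose gradient is 2*p (assumption for the demo).
p, m, v = 1.0, 0.0, 0.0
for t in range(1, 101):
    p, m, v = adamw_step(p, 2 * p, m, v, t, lr=0.1)
print(p)
```

Dividing by the square root of `v_hat` is the per-element normalization mentioned above: each parameter's step size depends on the recent magnitude of its own gradients.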
58:28
The first thing is to always make sure that Copilot did not forget to zero the gradients. Always remember that you have to start with a zero gradient. Then, when you get your loss and you call .backward(), .backward() adds to the gradients; it deposits gradients, always doing a += on whatever the gradients are, which is why you must set them to zero. So this accumulates the gradient from this loss, and then we call the step function on the optimizer to update the parameters and decrease the loss.
59:03
Then we print the step and loss.item(). .item() is used here because loss is a tensor with a single element, and .item() will convert that to a single float, and this float will live on the CPU. This gets into some of the internals of the devices again, but loss is a tensor with a single element, and it lives on the GPU for me because I'm using GPUs. When you call .item(), PyTorch behind the scenes will take that single-element tensor, ship it back to CPU memory, and convert it into a float that we can just print.
59:38
So this is the optimization, and this should probably just work. Let's see what happens. Actually, sorry, instead of using the CPU override, let me delete that so this is a bit faster for me and it runs on CUDA.
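The loop described above can be sketched end to end as follows. This is a hedged, CPU-only illustration: `TinyLM` is a hypothetical stand-in for the GPT model (an embedding plus a linear layer), the tokens are random placeholders, and the vocabulary is shrunk to 64 entries so the example runs in well under a second.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size, B, T = 64, 4, 32  # tiny stand-in sizes (assumption)

class TinyLM(nn.Module):
    """Hypothetical stand-in for the GPT model: forward returns (logits, loss)."""
    def __init__(self):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, 32)
        self.head = nn.Linear(32, vocab_size)

    def forward(self, idx, targets=None):
        logits = self.head(self.emb(idx))  # (B, T, vocab_size)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss

buf = torch.randint(0, vocab_size, (B * T + 1,))  # random placeholder tokens
x, y = buf[:-1].view(B, T), buf[1:].view(B, T)

model = TinyLM()
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
losses = []
for i in range(50):
    optimizer.zero_grad()        # start every step from zero gradients
    logits, loss = model(x, y)
    loss.backward()              # adds (+=) into each parameter's .grad
    optimizer.step()             # update the parameters
    losses.append(loss.item())   # .item() returns a plain Python float
    print(f"step {i}, loss: {losses[-1]}")
```

Because the same batch is reused at every step, the printed loss should drift downward as the model starts to memorize it.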
59:58
Oh: "Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu". So cuda:0 is the zeroth GPU, because I actually have eight GPUs on this box, so it's the zeroth GPU in my box, and the CPU. The model we have moved to the device, but when I was writing this code I actually introduced a bug, because buf we never moved to the device.
1:00:23
And you have to be careful, because you can't just do buf.to(device). It's not stateful; it doesn't convert buf in place to be on the device. It instead returns a pointer to new memory, which is on the device. So you see how we can just do model.to(device); that does not apply to tensors. You have to do buf = buf.to(device), and then this should work.
1:00:52
okay so what do we expect to see we
1:00:52
okay so what do we expect to see we expect to see a reasonable loss in the
1:00:53
expect to see a reasonable loss in the
1:00:53
expect to see a reasonable loss in the beginning and then we continue to
1:00:55
beginning and then we continue to
1:00:55
beginning and then we continue to optimize just the single batch and so we
1:00:57
optimize just the single batch and so we
1:00:57
optimize just the single batch and so we want to see that we can overfit this
1:00:58
want to see that we can overfit this
1:00:58
want to see that we can overfit this single batch we can we can crush this
1:01:01
single batch we can we can crush this
1:01:01
single batch we can we can crush this little batch and we can perfectly
1:01:02
little batch and we can perfectly
1:01:02
little batch and we can perfectly predict the indices on just this little
1:01:04
predict the indices on just this little
1:01:04
predict the indices on just this little batch and indeed that is roughly what
1:01:06
batch and indeed that is roughly what
1:01:06
batch and indeed that is roughly what we're seeing here
1:01:08
we're seeing here
1:01:08
we're seeing here so um we started off at roughly 10.82 11
1:01:12
so um we started off at roughly 10.82 11
1:01:12
We started off at roughly 10.82 (11 in this case), and as we continue optimizing on this single batch without loading new examples, we are making sure that we can overfit a single batch. We are getting to a very, very low loss, so the Transformer is memorizing this single individual batch.
1:01:26
One more thing I didn't mention: the learning rate here is 3e-4, which is a pretty good default for most optimizations that you want to run at a very early debugging stage.
1:01:38
So this is our simple inner loop, we are overfitting a single batch, and this looks good. What comes next is that we don't just want to overfit a single batch, we actually want to do an optimization. So we need to iterate over these x, y batches and create a little data loader that makes sure we're always getting a fresh batch and that we're actually optimizing a reasonable objective. So let's do that next.
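As a rough illustration of the single-batch check described above, here is a minimal sketch. The tiny embedding-plus-linear model, the toy vocabulary size, and the random token batch are hypothetical stand-ins and not the GPT-2 model from the video. AdamW is also an assumption, since the passage only states the learning rate of 3e-4. A model this small will not memorize the batch in 50 steps, so the sketch only shows the mechanics: the same batch is reused every step and the loss should fall.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

vocab_size, n_embd = 64, 32  # toy sizes, not GPT-2's
B, T = 4, 32                 # batch size and sequence length

# Hypothetical stand-in model: token embedding followed by a linear classifier.
model = nn.Sequential(nn.Embedding(vocab_size, n_embd), nn.Linear(n_embd, vocab_size))

# One fixed batch: y holds the next token for every position in x.
buf = torch.randint(0, vocab_size, (B * T + 1,))
x, y = buf[:-1].view(B, T), buf[1:].view(B, T)

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
losses = []
for step in range(50):
    optimizer.zero_grad()
    logits = model(x)  # (B, T, vocab_size)
    loss = F.cross_entropy(logits.view(-1, vocab_size), y.view(-1))
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(f"first loss {losses[0]:.3f}, last loss {losses[-1]:.3f}")
```

If the loss does not fall at all on a batch that never changes, that usually points to a bug in the model or the loop rather than to a data problem.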
1:02:01
Okay, so this is what I came up with: I wrote a little DataLoaderLite. What this data loader does is: we import tiktoken up here, read the entire text file from this single input.txt, tokenize it, and then print the total number of tokens and the number of batches in a single epoch of iterating over this dataset, that is, how many unique batches we output before we loop back around to the beginning of the document and start reading it again.
1:02:28
We start off at position zero and then simply walk the document in batches of B * T: we take chunks of B * T and always advance by B * T. It's important to note that we always advance our position by exactly B * T, but when we fetch the tokens we actually fetch B * T + 1 of them, starting from the current position. We need that plus one because we need the target token for the last token in the current batch. That way we can build x and y exactly as we did before. If we run out of data, we just loop back around to zero.
1:03:09
So this is one way to write a very, very simple data loader that just goes through the file in chunks. It is good enough for our current purposes, and we're going to complexify it later.
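The chunking logic described above can be sketched as follows. This is a dependency-free approximation, not the code from the video: plain Python lists stand in for the token tensor, and the tokens are passed in directly instead of being read from input.txt and encoded with tiktoken. The attribute names `current_position` and `batches_per_epoch` are assumptions made for illustration.

```python
class DataLoaderLite:
    """Walks a flat token list in chunks of B*T tokens, wrapping around at the end."""

    def __init__(self, tokens, B, T):
        self.tokens = list(tokens)
        self.B = B
        self.T = T
        self.current_position = 0
        # Number of non-overlapping B*T chunks in one pass over the data.
        self.batches_per_epoch = len(self.tokens) // (B * T)

    def next_batch(self):
        B, T = self.B, self.T
        start = self.current_position
        # Fetch B*T + 1 tokens: the extra one is the target for the last input token.
        buf = self.tokens[start : start + B * T + 1]
        inputs, targets = buf[:-1], buf[1:]
        # Reshape the flat lists into B rows of T tokens each.
        x = [inputs[r * T : (r + 1) * T] for r in range(B)]
        y = [targets[r * T : (r + 1) * T] for r in range(B)]
        # Advance by exactly B*T, so the next batch starts on the token
        # that served as the last target of this batch.
        self.current_position += B * T
        # If the next fetch would run past the end, loop back to the start.
        if self.current_position + B * T + 1 > len(self.tokens):
            self.current_position = 0
        return x, y


loader = DataLoaderLite(tokens=range(20), B=2, T=3)
x, y = loader.next_batch()
print(x)  # [[0, 1, 2], [3, 4, 5]]
print(y)  # [[1, 2, 3], [4, 5, 6]]
```

Each row of `y` is the matching row of `x` shifted by one token, which is why the fetch needs one token more than the amount the position advances by.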
1:03:26
we'd like to come back around here and
1:03:26
we'd like to come back around here and we'd like to actually use our data
1:03:27
we'd like to actually use our data
1:03:27
we'd like to actually use our data loader so the import Tik token has moved
1:03:29
loader so the import Tik token has moved
1:03:29
loader so the import Tik token has moved up and actually all of this is now
1:03:32
up and actually all of this is now
1:03:32
up and actually all of this is now useless so instead we just want a train
1:03:34
useless so instead we just want a train
1:03:35
useless so instead we just want a train loader for the training data and we want
1:03:38
loader for the training data and we want
1:03:38
loader for the training data and we want to use the same hyper parameters for
1:03:39
to use the same hyper parameters for
1:03:39
to use the same hyper parameters for four so B size was four and time was
1:03:43
four so B size was four and time was
1:03:43
four so B size was four and time was 32 and then here we need to get the XY
1:03:47
32 and then here we need to get the XY
1:03:47
32 and then here we need to get the XY for the current batch so let's see if
1:03:49
for the current batch so let's see if
1:03:49
for the current batch so let's see if copal gets it because this is simple
1:03:51
copal gets it because this is simple
1:03:51
copal gets it because this is simple enough uh so we call the next batch and
1:03:53
enough uh so we call the next batch and
1:03:53
enough uh so we call the next batch and then we um make sure that we have to
1:03:57
then we um make sure that we have to
1:03:57
then we um make sure that we have to move our tensors from CPU to the device
1:04:02
move our tensors from CPU to the device
1:04:02
move our tensors from CPU to the device so here when I converted the tokens
1:04:05
so here when I converted the tokens
1:04:05
so here when I converted the tokens notice that I didn't actually move these
1:04:06
notice that I didn't actually move these
1:04:06
notice that I didn't actually move these tokens to the GPU I left them on CPU
1:04:10
tokens to the GPU I left them on CPU
1:04:10
tokens to the GPU I left them on CPU which is the default um and that's just
1:04:12
which is the default um and that's just
1:04:12
which is the default um and that's just because I'm trying not to waste too much
1:04:14
because I'm trying not to waste too much
1:04:14
because I'm trying not to waste too much memory on the GPU in this case this is a
1:04:16
memory on the GPU in this case this is a
1:04:16
memory on the GPU in this case this is a tiny data set and it would fit uh but
1:04:19
tiny data set and it would fit uh but
1:04:19
tiny data set and it would fit uh but it's fine to just uh ship it to GPU
1:04:21
it's fine to just uh ship it to GPU
1:04:21
it's fine to just uh ship it to GPU right now for for our purposes right now
1:04:24
right now for for our purposes right now
1:04:24
right now for for our purposes right now so we get the next batch we keep the
1:04:25
so we get the next batch we keep the
1:04:26
so we get the next batch we keep the data loader simple CPU class and then
1:04:29
data loader simple CPU class and then
1:04:29
data loader simple CPU class and then here we actually ship it to the GPU and
1:04:31
here we actually ship it to the GPU and
1:04:31
here we actually ship it to the GPU and do all the computation and uh let's see
1:04:34
do all the computation and uh let's see
1:04:34
do all the computation and uh let's see if this runs so python train gbt2 pi and
1:04:39
if this runs so python train gbt2 pi and
1:04:39
if this runs so python train gbt2 pi and what do we expect to see before this
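The pattern of keeping the data on the CPU and moving only the current batch to the compute device might look like the sketch below. The `torch.arange` buffer is a hypothetical stand-in for the loader's output, and the device selection line is an assumption (the passage only says the tensors are moved "to the device").

```python
import torch

# Pick an accelerator if one is available, otherwise stay on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

B, T = 4, 32
# Stand-in for one batch from the loader: token ids created on the CPU (the default).
buf = torch.arange(B * T + 1)
x, y = buf[:-1].view(B, T), buf[1:].view(B, T)
print(x.device)  # cpu

# Move only the current batch to the compute device, right before the forward pass.
x, y = x.to(device), y.to(device)
print(x.device, x.shape, y.shape)
```

Only B * T token ids travel to the device per step, so the full token array never has to occupy GPU memory.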
1:04:41
what do we expect to see before this
1:04:41
what do we expect to see before this actually happens what we expect to see
1:04:43
actually happens what we expect to see
1:04:43
actually happens what we expect to see is now we're actually getting the next
1:04:44
is now we're actually getting the next
1:04:44
is now we're actually getting the next batch so we expect to not overfit a
1:04:47
batch so we expect to not overfit a
1:04:47
batch so we expect to not overfit a single batch and so I expect our loss to
1:04:50
single batch and so I expect our loss to
1:04:50
single batch and so I expect our loss to come down but not too much and that's
1:04:54
come down but not too much and that's
1:04:54
come down but not too much and that's because I still expect it to come down
1:04:55
because I still expect it to come down
1:04:55
because I still expect it to come down because in the
1:04:57
because in the
1:04:57
because in the 50257 tokens many of those tokens never
1:05:00
50257 tokens many of those tokens never
1:05:00
50257 tokens many of those tokens never occur in our data set so there are some
1:05:02
occur in our data set so there are some
1:05:02
occur in our data set so there are some very easy gains to be made here in the
1:05:04
very easy gains to be made here in the
1:05:04
very easy gains to be made here in the optimization by for example taking the
1:05:06
optimization by for example taking the
1:05:06
optimization by for example taking the biases of all the loits that never occur
1:05:08
biases of all the loits that never occur
1:05:08
biases of all the loits that never occur and driving them to negative infinity
1:05:11
and driving them to negative infinity
1:05:11
and driving them to negative infinity and that would basically just it's just
1:05:12
and that would basically just it's just
1:05:12
and that would basically just it's just that all of these crazy unic codes or
1:05:14
that all of these crazy unic codes or
1:05:14
that all of these crazy unic codes or different languages those tokens never
1:05:16
different languages those tokens never
1:05:16
different languages those tokens never occur so their probability should be
1:05:17
occur so their probability should be
1:05:17
occur so their probability should be very low and so the gains that we should
1:05:19
very low and so the gains that we should
1:05:19
very low and so the gains that we should be seeing are along the lines of
1:05:22
be seeing are along the lines of
1:05:22
be seeing are along the lines of basically deleting the usage of tokens
1:05:24
basically deleting the usage of tokens
1:05:24
basically deleting the usage of tokens that never occur that's probably most of
1:05:26
that never occur that's probably most of
1:05:26
that never occur that's probably most of the loss gain that we're going to see at
1:05:28
the loss gain that we're going to see at
1:05:28
the loss gain that we're going to see at this scale right now uh but we shouldn't
1:05:30
this scale right now uh but we shouldn't
1:05:30
this scale right now uh but we shouldn't come to a zero uh because um we are only
1:05:35
come to a zero uh because um we are only
1:05:35
come to a zero uh because um we are only doing 50 iterations and I don't think
1:05:37
doing 50 iterations and I don't think
1:05:37
doing 50 iterations and I don't think that's enough to do an eoch right now so
1:05:39
that's enough to do an eoch right now so
1:05:39
that's enough to do an eoch right now so let's see what we
1:05:40
let's see what we
1:05:40
let's see what we got we um we have 338,000
1:05:44
got we um we have 338,000
1:05:44
got we um we have 338,000 tokens which makes sense with our 3:1
1:05:47
tokens which makes sense with our 3:1
1:05:47
tokens which makes sense with our 3:1 compression ratio because there are 1
1:05:48
compression ratio because there are 1
1:05:48
compression ratio because there are 1 million uh characters so one Epoch with
1:05:52
million uh characters so one Epoch with
1:05:52
million uh characters so one Epoch with the current setting of B and T will take
1:05:55
the current setting of B and T will take
1:05:55
the current setting of B and T will take 2, 600 batches and we're only doing 50
1:05:58
2, 600 batches and we're only doing 50
1:05:58
2, 600 batches and we're only doing 50 batches of optimization in
1:06:00
batches of optimization in
1:06:01
batches of optimization in here so we start off in a familiar
1:06:03
here so we start off in a familiar
1:06:03
here so we start off in a familiar territory as expected and then we seem
1:06:05
territory as expected and then we seem
1:06:05
territory as expected and then we seem to come down to about 6.6 so basically
1:06:09
to come down to about 6.6 so basically
1:06:09
to come down to about 6.6 so basically things seem to be working okay right now
1:06:11
things seem to be working okay right now
1:06:11
things seem to be working okay right now with respect to our expectations so
1:06:13
with respect to our expectations so
1:06:13
with respect to our expectations so that's good okay next I want to actually
1:06:15
that's good okay next I want to actually
1:06:16
that's good okay next I want to actually fix a bug that we have in our code um
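The numbers quoted above can be checked with a few lines of arithmetic. This is only a sanity-check sketch: the token count is the rounded figure from the transcript, so the batch count is approximate, and reading a loss of 6.6 as "a uniform guess over about exp(6.6) tokens" is an interpretation added here, not something stated in the video.

```python
import math

vocab_size = 50257
# Loss of a model that spreads probability uniformly over the whole vocabulary.
uniform_loss = -math.log(1.0 / vocab_size)
print(f"{uniform_loss:.2f}")  # 10.82

# Batches per epoch for the figures quoted above.
n_tokens = 338_000  # rounded token count from the transcript
B, T = 4, 32
batches_per_epoch = n_tokens // (B * T)
print(batches_per_epoch)  # 2640

# Fraction of one epoch covered by 50 optimization steps.
print(f"{50 / batches_per_epoch:.1%}")  # 1.9%

# A cross-entropy of 6.6 nats matches a uniform guess over about exp(6.6) tokens.
print(round(math.exp(6.6)))  # 735
```

Going from roughly 50,257 equally likely choices to the equivalent of a few hundred is consistent with the model mostly learning to ignore tokens that never appear in the data.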
1:06:18
Okay, next I want to fix a bug that we have in our code. It's not a major bug, but it is a bug with respect to how GPT-2 training should happen. The bug is the following: we were not being careful enough when we were loading the weights from Hugging Face, and we actually missed a little detail.
1:06:35
If we come here, notice that the shape of these two tensors is the same. This one here is the token embedding at the bottom of the Transformer, and this one here is the language modeling head at the top of the Transformer. Both of these are two-dimensional tensors and their shape is identical. The first one is the token embedding, and the second one is the linear layer at the very top, the classifier layer. Both of them are of shape 50257 x 768. This one here gives us our token embeddings at the bottom, and this one here takes the 768 channels of the Transformer and upscales them to 50,257 to get the logits for the next token.
1:07:26
So they're both the same shape, but more than that, if you compare their elements (in PyTorch this is an element-wise equality, so then we use .all()), we see that every single element is identical. And more than that, if we look at the data pointer, which is a way in PyTorch to get the actual pointer to the data and the storage, we see that the pointer is identical too. So these are not just two separate tensors that happen to have the same shape and elements; they are actually pointing to the identical tensor.
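The three checks described above (same shape, same elements, same data pointer) can be reproduced on small stand-in tensors. This sketch does not load the Hugging Face checkpoint; `a`, `b`, and `c` are hypothetical tensors used only to show how equal values differ from shared storage.

```python
import torch

a = torch.randn(5, 3)
b = a.clone()  # separate tensor with the same values
c = a          # second name for the very same tensor

print(a.shape == b.shape)            # True
print((a == b).all().item())         # True
print(a.data_ptr() == b.data_ptr())  # False
print(a.data_ptr() == c.data_ptr())  # True

# Writing through one alias is visible through the other, but not in the clone.
c[0, 0] = 123.0
print(a[0, 0].item(), b[0, 0].item() == 123.0)  # 123.0 False
```

Element-wise equality alone cannot tell a copy from an alias, which is why the data pointer comparison is the decisive check.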
1:08:03
What's happening here is a common weight tying scheme that actually comes from the original Attention Is All You Need paper, and actually even from the reference before it. If we come here, to Embeddings and Softmax in the Attention Is All You Need paper, they mention: "In our model, we share the same weight matrix between the two embedding layers and the pre-softmax linear transformation, similar to [30]." This is an awkward way to phrase that these two are shared: they're tied, and they're the same matrix. The [30] reference is this paper, which came out in 2017. You can read the full paper, but basically it argues for this weight tying scheme.
1:08:49
I think intuitively the idea for why you might want to do this comes from this paragraph here. You can observe that you actually want these two matrices to behave similarly in the following sense. If two tokens are very similar semantically (maybe one of them is all lowercase and the other is all uppercase, or it's the same token in a different language, or something like that), presumably you would expect them to be nearby in the token embedding space. But in the exact same way, if you have two tokens that are semantically similar, you'd expect them to get the same probabilities at the output of a Transformer. So both positions in the Transformer, at the very bottom and at the top, have this property that similar tokens should have similar embeddings or similar weights, and this is what motivates their exploration here.
1:09:51
I don't want to go through the entire paper, and you can go through it, but this is what they observe. They also observe that the output embeddings behave like word embeddings if you just try to use those weights as word embeddings. So they observe this similarity, they try tying them, and they observe that they can get much better performance that way. This was adopted in the Attention Is All You Need paper, and then it was used again in GPT-2 as well.
1:10:26
I couldn't find it in the Transformers implementation; I'm not sure where they tie those embeddings. But I can find it in the original GPT-2 code released by OpenAI. This is openai/gpt-2, src/model.py. Here, where they are forwarding this model (this is in TensorFlow, but that's okay), we see that they get the wte token embeddings, and then here is the encoding of the token embeddings and the positions. Then here at the bottom they use the wte again to compute the logits: when they get the logits, it's a matmul of the output from the Transformer and the wte tensor, which is reused.
1:11:05
So the wte tensor is basically used twice, at the bottom of the Transformer and at the top of the Transformer, and in the backward pass we'll get gradient contributions from both branches. These gradients will add up on the wte tensor: we'll get a contribution from the classifier layer, and then at the very end of the backward pass, at the bottom of the Transformer, we'll get another contribution flowing into the wte tensor. We are currently not sharing wte in our code, but we want to do that.
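To make the "gradients add up" point concrete, here is a small PyTorch sketch (not the TensorFlow code referenced above). One matrix is used both for the embedding lookup and for the logits, and its gradient is compared with the sum of the gradients obtained when the two roles are held by separate copies. The toy sizes and the `tanh` standing in for the Transformer blocks are assumptions made for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size, n_embd = 10, 4  # toy sizes

# One shared matrix plays both roles, like wte in the passage above.
wte = torch.randn(vocab_size, n_embd, requires_grad=True)

idx = torch.tensor([1, 2, 3])      # input token ids
targets = torch.tensor([2, 3, 4])  # next-token targets

h = wte[idx]        # bottom: embedding lookup, shape (3, n_embd)
h = torch.tanh(h)   # hypothetical stand-in for the Transformer blocks
logits = h @ wte.T  # top: reuse wte as the classifier, shape (3, vocab_size)
loss = F.cross_entropy(logits, targets)
loss.backward()
total_grad = wte.grad.clone()

# Recompute with the two roles separated to see each branch's contribution.
w_bottom = wte.detach().clone().requires_grad_(True)
w_top = wte.detach().clone().requires_grad_(True)
loss2 = F.cross_entropy(torch.tanh(w_bottom[idx]) @ w_top.T, targets)
loss2.backward()

# The tied gradient is the sum of the embedding and classifier gradients.
print(torch.allclose(total_grad, w_bottom.grad + w_top.grad))  # True
```

Rows of the embedding branch's gradient are zero for tokens that do not appear in the input, while the classifier branch touches every row through the softmax.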
1:11:48
so weight sharing scheme um and one way
1:11:48
so weight sharing scheme um and one way to do this let's see if goil gets it oh
1:11:50
to do this let's see if goil gets it oh
1:11:50
to do this let's see if goil gets it oh it does okay uh so this is one way to do
1:11:54
it does okay uh so this is one way to do
1:11:54
it does okay uh so this is one way to do it
1:11:56
Basically, it's relatively straightforward. What we're doing here is taking wte.weight and simply redirecting it to point to the lm_head weight. This copies the data pointer, right, it copies the reference. The old value of wte.weight becomes orphaned, and Python will clean it up. So we are left with a single tensor, and it's going to be used twice in the forward pass.
1:12:31
This is, to my knowledge, all that's required, so we should be able to use this and it should probably train. We're just going to be using this exact same tensor twice.
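The pointer-sharing assignment described above can be sketched on a toy model. This is a minimal sketch, not the file from the video: the class name `TinyTiedLM` and the small sizes are hypothetical, and only the attribute names `wte` and `lm_head` come from the transcript.

```python
import torch
import torch.nn as nn

class TinyTiedLM(nn.Module):  # hypothetical toy model, not the full GPT
    def __init__(self, vocab_size=100, n_embd=16):
        super().__init__()
        self.wte = nn.Embedding(vocab_size, n_embd)               # token embedding
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)  # classifier
        # Weight tying: both modules now reference the same Parameter object.
        self.wte.weight = self.lm_head.weight

    def forward(self, idx):
        x = self.wte(idx)       # first use: embedding lookup
        return self.lm_head(x)  # second use: output projection

model = TinyTiedLM()
tied = model.wte.weight.data_ptr() == model.lm_head.weight.data_ptr()
n_params = sum(p.numel() for p in model.parameters())  # shared tensor is counted once
print(tied, n_params)
```

Both weights have shape (vocab_size, n_embd), which is why the assignment works without a transpose: nn.Linear stores its weight as (out_features, in_features).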
1:12:44
We weren't being careful with tracking the likelihoods, but according to the paper and according to the results, you'd actually expect slightly better results doing this.
1:12:53
In addition to that, one other reason this is very, very nice for us is that this is a ton of parameters, right? What is the size here? It's 768 * 50257, so this is roughly 40 million parameters, and this is a 124 million parameter model. So 40 divided by 124: roughly 30% of the parameters are being saved using this weight tying scheme.
1:13:18
And so this might be one of the reasons that this works slightly better if you're not training the model long enough: because of the weight tying you don't have to train as many parameters, so you become more efficient in terms of the training process. You have fewer parameters, and you're putting in this inductive bias that these two embeddings should share similarities between tokens.
1:13:40
So this is the weight tying scheme. We've saved a ton of parameters, and we expect our model to work slightly better because of the scheme.
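The arithmetic behind the "roughly 40 million" and "roughly 30%" figures can be checked directly. The sketch below uses the two sizes quoted in the transcript and the nominal 124 million total; the variable names are only illustrative.

```python
n_embd = 768
vocab_size = 50257
total_params = 124_000_000  # nominal model size quoted in the transcript

saved = n_embd * vocab_size  # one (vocab_size x n_embd) matrix no longer duplicated
fraction = saved / total_params

print(f"{saved:,} parameters saved")          # 38,597,376 parameters saved
print(f"{fraction:.1%} of the nominal 124M")  # 31.1% of the nominal 124M
```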
1:13:47
Okay, next I would like us to be a bit more careful with the initialization and to try to follow the way GPT-2 initialized their model. Unfortunately, the GPT-2 paper and the GPT-3 paper are not very explicit about initialization, so we kind of have to read between the lines.
1:14:02
Instead of going to the paper, which is quite vague, there's a bit of information in the code that OpenAI released. When we go to model.py, we see that when they initialize their weights they use a standard deviation of 0.02. So this is a normal distribution for the weights, and the standard deviation is 0.02. The bias they initialize with zero.
1:14:26
And then when we scroll down here (why is this not scrolling?), the token embeddings are initialized at 0.02 and the position embeddings at 0.01, for some reason. So those are the initializations, and we'd like to mirror that in our GPT-2 module here. So here's a snippet of code that I came up with very quickly.
1:14:55
So what's happening here is that at the end of our initializer for the GPT module we're calling the apply function of nn.Module, and that iterates over all the submodules of this module and applies the _init_weights function on them.
1:15:08
So we're iterating over all the modules here, and if they are an nn.Linear module then we make sure to initialize the weight using a normal with a standard deviation of 0.02. If there's a bias in this layer, we make sure to initialize it to zero. Note that zero initialization for the bias is not actually the PyTorch default; by default the bias here is initialized with a uniform, so that's interesting. So we make sure to use zero.
1:15:38
For the embedding we're just going to use 0.02 and keep it the same, so we're not going to change it to 0.01 for the positional embedding, because it's about the same.
1:15:47
And then if you look through our model, the only other layer that requires initialization and that has parameters is the layer norm. The PyTorch default initialization sets the scale in the layer norm to one and the offset in the layer norm to zero, so that's exactly what we want, and we're just going to keep it that way.
1:16:06
So this is the default initialization if we are following the GPT-2 source code that they released.
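The initialization walked through above can be sketched as a standalone function. In the video it is a method applied from the GPT module's initializer; here it is a free function run on a hypothetical toy `nn.Sequential`, purely so the snippet is self-contained.

```python
import torch
import torch.nn as nn

def _init_weights(module):
    # Linear layers: normal weights with std 0.02, zero bias.
    if isinstance(module, nn.Linear):
        torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if module.bias is not None:
            torch.nn.init.zeros_(module.bias)
    # Embeddings (token and positional alike): normal with std 0.02.
    elif isinstance(module, nn.Embedding):
        torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
    # LayerNorm is left untouched: PyTorch already uses scale 1, offset 0.

torch.manual_seed(0)
toy = nn.Sequential(nn.Embedding(1000, 64), nn.Linear(64, 256), nn.LayerNorm(256))
toy.apply(_init_weights)  # apply() visits every submodule recursively

print(toy[1].weight.std().item(), toy[1].bias.abs().max().item())
```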
1:16:17
I would like to point out, by the way, that typically the standard deviation on this initialization, if you follow the Xavier initialization, would be one over the square root of the number of features that are incoming into this layer. But if you'll notice, 0.02 is basically consistent with that, because the model sizes inside these Transformers for GPT-2 are roughly 768, 1600, etc.
1:16:41
So 1 over the square root of, for example, 768 gives us about 0.036. If we plug in 1,600 we get 0.025. If we plug in three times that, 0.014, etc. So basically 0.02 is roughly in the vicinity of reasonable values for these initializations anyway.
1:17:05
So it's not completely crazy to be hard-coding 0.02 here, but you'd typically like something that scales with the model size instead. We will keep this, because that is the GPT-2 initialization per their source code.
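The comparison with one over the square root of the incoming feature count is easy to reproduce. This short check uses only the three widths mentioned in the transcript; the dictionary name `stds` is just for illustration.

```python
import math

# Incoming feature counts mentioned in the transcript: 768, 1600, and 3 * 1600.
stds = {fan_in: 1 / math.sqrt(fan_in) for fan_in in (768, 1600, 3 * 1600)}

for fan_in, std in stds.items():
    print(f"fan_in={fan_in}: 1/sqrt(fan_in) = {std:.4f}")
# fan_in=768: 1/sqrt(fan_in) = 0.0361
# fan_in=1600: 1/sqrt(fan_in) = 0.0250
# fan_in=4800: 1/sqrt(fan_in) = 0.0144
```

All three values sit within a factor of two of the hard-coded 0.02.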
1:17:19
But we are not fully done yet on initialization, because there's one more caveat here: "A modified initialization which accounts for the accumulation on the residual path with model depth is used. We scale the weights of residual layers at initialization by a factor of one over the square root of N, where N is the number of residual layers." So this is what the GPT-2 paper says.
1:17:38
We have not implemented that yet, and we can do so now. I'd like to motivate a little bit what they mean here, I think. So here's roughly what they mean.
1:17:52
If you start out with zeros in your residual stream, remember that each residual stream is of this form where we continue adding to it: x is x plus something, some kind of contribution. So every single block of the residual network contributes some amount, and it gets added. What ends up happening is that the variance of the activations in the residual stream grows.
1:18:18
So here's a small example. We start with a residual stream of 768 zeros, and then 100 times we add random numbers drawn from a normal distribution with zero mean and standard deviation one. By the end, the residual stream has grown to have a standard deviation of 10, and that's just because we're always adding these numbers.
1:18:47
The scaling factor that they use here exactly compensates for that growth. If we take n and scale down every one of these contributions into the residual stream by one over the square root of n (that is, n to the power -0.5, because n to the 0.5 is the square root, and one over the square root is n to the -0.5), then we see that we actually get a standard deviation of one.
1:19:21
So this is a way to control the growth of activations inside the residual stream in the forward pass.
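The growth experiment described above can be reproduced in a few lines. This sketch uses plain Python rather than PyTorch so it runs without dependencies; the helper name `residual_std` and the fixed seed are assumptions made for the example, not part of the video's code.

```python
import random
import statistics

def residual_std(n_layers, width=768, scale=1.0, seed=0):
    """Add `n_layers` scaled standard-normal contributions to a zero vector
    and return the standard deviation of the result."""
    rng = random.Random(seed)
    x = [0.0] * width
    for _ in range(n_layers):
        x = [xi + scale * rng.gauss(0.0, 1.0) for xi in x]
    return statistics.pstdev(x)

n = 100
unscaled = residual_std(n)                 # close to sqrt(100) = 10
scaled = residual_std(n, scale=n ** -0.5)  # close to 1
print(unscaled, scaled)
```

The sum of n independent unit-variance terms has variance n, so its standard deviation is the square root of n; multiplying each term by n to the -0.5 brings the variance of the sum back to one.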
1:19:29
And so we'd like to initialize in the same way. For the weights that are at the end of each block, so this c_proj layer, the GPT-2 paper proposes to scale them down by one over the square root of the number of residual layers.
1:19:43
One crude way to implement this is the following. I don't know if this is PyTorch sanctioned, but it works for me: in the initialization we set self.c_proj.NANOGPT_SCALE_INIT = 1. So we're setting kind of like a flag for this module. There must be a better way in PyTorch, right? But I don't know.
1:20:11
Okay, so we're basically attaching this flag and trying to make sure that it doesn't conflict with anything previously. And then when we come down here, this std should be 0.02 by default, but then if the module has this flag (hasattr on the module), then std *= ... Copilot is not guessing correctly. We want one over the square root of the number of layers. The number of residual layers here is 2 times self.config.n_layer, and then this to the power of -0.5. So we want to scale down that standard deviation, and this should correctly implement that.
1:21:03
I should clarify, by the way, that the 2 times number of layers comes from the fact that every single one of our layers in the Transformer actually has two blocks that add to the residual pathway, right? We have the attention and then the MLP. So that's where the 2 times comes from.
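The flag-based scaling described above can be sketched as follows. This is an illustration under stated assumptions, not the video's file: `ToyMLP` is a hypothetical stand-in for one residual branch, and `N_LAYER = 12` is hard-coded here where the transcript reads the depth from the model config.

```python
import torch
import torch.nn as nn

N_LAYER = 12  # assumed depth for this sketch (GPT-2 124M has 12 layers)

class ToyMLP(nn.Module):  # hypothetical stand-in for one residual branch
    def __init__(self, n_embd=768):
        super().__init__()
        self.c_fc = nn.Linear(n_embd, 4 * n_embd)
        self.c_proj = nn.Linear(4 * n_embd, n_embd)
        self.c_proj.NANOGPT_SCALE_INIT = 1  # flag: output is added to the residual stream

def _init_weights(module):
    if isinstance(module, nn.Linear):
        std = 0.02
        if hasattr(module, "NANOGPT_SCALE_INIT"):
            # Two residual additions per layer: attention and MLP.
            std *= (2 * N_LAYER) ** -0.5
        torch.nn.init.normal_(module.weight, mean=0.0, std=std)
        if module.bias is not None:
            torch.nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

torch.manual_seed(0)
mlp = ToyMLP()
mlp.apply(_init_weights)
expected = 0.02 * (2 * N_LAYER) ** -0.5  # about 0.0041
print(mlp.c_fc.weight.std().item(), mlp.c_proj.weight.std().item(), expected)
```

Only the flagged projection gets the smaller standard deviation; the unflagged `c_fc` layer keeps 0.02.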
1:21:18
The other thing to mention, which is slightly awkward but we're not going to fix it, is that because we are weight sharing the wte and the lm_head, in this iteration over all our submodules we're actually going to come around to that tensor twice. So we're going to first initialize it as an embedding with 0.02, and then we're going to come back around to it again as a linear layer and initialize it again using 0.02. It's going to be 0.02 because the lm_head is of course not scaled, so it's not going to come here. It's basically going to be initialized twice using the identical initialization, but that's okay.
1:21:56
And then scrolling over here, I added some code so that we have reproducibility, to set the seeds. And now we should be able to run python train_gpt2.py and let this run. As far as I know, this is the GPT-2 initialization in the way we've implemented it right now, so this looks reasonable to me.
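Seeding for reproducibility, as mentioned above, might look like the sketch below. The transcript does not state the seed value or the exact calls, so `SEED = 1337` and the helper `set_seed` are assumptions made for illustration.

```python
import torch

SEED = 1337  # arbitrary value chosen for this sketch

def set_seed(seed):
    torch.manual_seed(seed)            # seeds the CPU generator
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)   # seeds the current CUDA device

set_seed(SEED)
a = torch.randn(4)
set_seed(SEED)
b = torch.randn(4)
print(torch.equal(a, b))  # re-seeding reproduces the same random draw
```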
1:22:19
looks uh reasonable to me okay so at
1:22:19
Okay, so at this point we have the GPT-2 model, we have some confidence that it's correctly implemented, we've initialized it properly, and we have a data loader that's iterating through data batches, so we can train. Now comes the fun part: I'd like us to speed up the training by a lot, so that we're getting our money's worth with respect to the hardware that we are using here. We're going to speed up the training by quite a bit. Now, you always want to start with: what hardware do you have, what does it offer, and are you fully utilizing it?
1:22:48
So in my case, if we go to nvidia-smi, we can see that I have eight GPUs, and each one of those GPUs is an A100 SXM 80GB. So this is the GPU that I have available to me in this box.
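The same hardware check can be scripted. This is a minimal sketch, assuming the nvidia-smi tool is on the PATH; the helper name list_gpus is made up for illustration, and the function simply returns an empty list on machines without NVIDIA tooling.

```python
import shutil
import subprocess


def list_gpus():
    """Return one 'name, memory' string per GPU, or [] if nvidia-smi is unavailable."""
    if shutil.which("nvidia-smi") is None:
        return []  # no NVIDIA driver tools on this machine
    try:
        result = subprocess.run(
            ["nvidia-smi", "--query-gpu=name,memory.total", "--format=csv,noheader"],
            capture_output=True,
            text=True,
            check=False,
        )
    except OSError:
        return []  # the tool exists but could not be launched
    if result.returncode != 0:
        return []
    return [line.strip() for line in result.stdout.splitlines() if line.strip()]


gpus = list_gpus()
print(f"found {len(gpus)} GPU(s)")
for gpu in gpus:
    print(" ", gpu)
```

On a box like the one in the talk this would be expected to list eight entries, one per GPU.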
1:23:07
Now, when I spin up these kinds of boxes, by the way, my favorite place to go to is Lambda Labs. They do sponsor my development and that of my projects, but this is my favorite place to go. This is where you can spin up one of these machines, you pay per hour, and it's very, very simple. So I like to spin them up and then connect VS Code to it, and that's how I develop. Now, when we look at the A100s that are available here, the A100 80GB SXM is the GPU that I have here, and we have a bunch of numbers here for how many calculations you can expect out of this GPU.
1:23:46
So when I come over here and I break in right after here, so python train_gpt2.py, I'm breaking in right after we calculate the logits and the loss. The interesting thing I'd like you to note is that when I do logits.dtype, this prints torch.float32. So by default in PyTorch, when you create tensors, and this is the case for all the activations and for the parameters of the network and so on, everything is in float32. That means that every single number, activation or weight and so on, is using a float representation that has 32 bits, and that's actually quite a bit of memory.
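The default dtype is easy to confirm outside the debugger. The sketch below assumes PyTorch is installed and uses a single stand-in linear layer rather than the model from the talk; the layer size and the seed are arbitrary choices for illustration.

```python
import torch

torch.manual_seed(0)  # arbitrary seed, only so the sketch is repeatable

layer = torch.nn.Linear(768, 768)  # stand-in for one linear layer of the model
x = torch.randn(4, 768)            # stand-in for a small batch of activations
logits = layer(x)

print(layer.weight.dtype)     # parameters are created as float32 by default
print(logits.dtype)           # activations computed from them are float32 too
print(logits.element_size())  # bytes per element
```

This runs on the CPU, so it shows only the default dtype, not anything about GPU speed.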
1:24:27
And it turns out empirically that for deep learning as a computational workload this is way too much: deep learning and the training of these networks can tolerate significantly lower precisions. Not all computational workloads can tolerate low precision. For example, if we go back to the datasheet, you'll see that these GPUs actually support up to FP64, and this is quite useful, I understand, for a lot of scientific computing applications, and they really need this. But we don't need that much precision for deep learning training.
1:24:59
So currently we are here, at FP32, and with this code as it is right now we expect to get at most 19.5 teraflops of performance. That means we're doing 19.5 trillion floating-point operations per second. These are floating-point multiply-adds, most likely, so these are the floating-point operations.
1:25:25
Now notice that if we are willing to go down in precision, so TF32 is a lower-precision format we're going to see in a second, you can actually get an 8x improvement here. And if you're willing to go down to float16 or bfloat16, you can actually get 16x the performance, all the way to 312 teraflops. You see here that NVIDIA likes to cite numbers that have an asterisk here. This asterisk says "with sparsity", but we are not going to be using sparsity in our code, and I don't know that this is very widely used in the industry right now. So most people look at this number here, without sparsity.
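The ratios quoted from the datasheet are simple arithmetic. In the sketch below, 19.5 and 312 are the figures read out in the talk; the TF32 entry is not read out and is inferred here from the quoted 8x, so treat it as an assumption.

```python
# Dense (no sparsity) peak throughput in TFLOPS.
# 19.5 and 312 are quoted in the talk; the TF32 value is inferred from "8x".
PEAK_TFLOPS = {
    "fp32": 19.5,
    "tf32": 19.5 * 8,
    "bf16/fp16": 312.0,
}

baseline = PEAK_TFLOPS["fp32"]
for name, tflops in PEAK_TFLOPS.items():
    print(f"{name:>10}: {tflops:6.1f} TFLOPS  ({tflops / baseline:.0f}x fp32)")
```

These are peak numbers from a datasheet; measured throughput in a real training loop is normally lower.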
1:26:03
without sparcity and you'll notice that we could have got even more here but
1:26:05
we could have got even more here but
1:26:05
we could have got even more here but this is int 8 and int 8 is used for
1:26:08
this is int 8 and int 8 is used for
1:26:08
this is int 8 and int 8 is used for inference not for training uh because
1:26:11
inference not for training uh because
1:26:11
inference not for training uh because int 8 has a um it basically has um
1:26:17
int 8 has a um it basically has um
1:26:17
int 8 has a um it basically has um uniform
1:26:18
uniform
1:26:18
uniform spacing um and uh we actually require a
1:26:21
spacing um and uh we actually require a
1:26:21
spacing um and uh we actually require a float so that we get a better match to
1:26:24
float so that we get a better match to
1:26:24
float so that we get a better match to the uh normal distributions that occur
1:26:28
the uh normal distributions that occur
1:26:28
the uh normal distributions that occur during training of neural networks where
1:26:29
during training of neural networks where
1:26:29
during training of neural networks where both activations and weights are
1:26:31
both activations and weights are
1:26:31
both activations and weights are distributed as a normal distribution and
1:26:33
distributed as a normal distribution and
1:26:33
distributed as a normal distribution and so uh floating points are really
1:26:35
so uh floating points are really
1:26:35
so uh floating points are really important to to match that uh
1:26:38
important to to match that uh
1:26:38
important to to match that uh representation so we're not typically
1:26:40
representation so we're not typically
1:26:40
representation so we're not typically using int 8 uh for training but we are
1:26:42
using int 8 uh for training but we are
1:26:42
using int 8 uh for training but we are using it for inference and if we bring
1:26:45
using it for inference and if we bring
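The spacing argument can be made concrete with the standard library alone. This sketch compares the gap between adjacent float32 values at several magnitudes with the constant gap of an 8-bit integer grid; the [-4, 4) range chosen for the integer grid is a hypothetical quantization range, not something from the talk.

```python
import struct


def float32_spacing(x):
    """Gap between x and the next larger float32 value (x positive and finite)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    here = struct.unpack("<f", struct.pack("<I", bits))[0]
    next_up = struct.unpack("<f", struct.pack("<I", bits + 1))[0]
    return next_up - here


# int8 has 256 evenly spaced levels. With a hypothetical range of [-4, 4),
# every gap has the same size regardless of magnitude.
int8_step = 8.0 / 256

for x in (0.001, 0.1, 1.0, 4.0):
    print(f"x={x:<6} float32 gap={float32_spacing(x):.3e}  int8 gap={int8_step:.3e}")
```

The float gaps shrink toward zero, which is where most normally distributed values sit, while the integer grid spends its levels evenly across the whole range.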
1:26:45
using it for inference and if we bring down the Precision we can get a lot more
1:26:47
down the Precision we can get a lot more
1:26:47
down the Precision we can get a lot more Terra flops out of the tensor course
1:26:49
Terra flops out of the tensor course
1:26:49
Terra flops out of the tensor course available in the gpus we'll talk about
1:26:51
available in the gpus we'll talk about
1:26:51
available in the gpus we'll talk about that in a second but in addition to that
1:26:53
that in a second but in addition to that
1:26:53
that in a second but in addition to that if all of these numbers have fewer bits
1:26:56
if all of these numbers have fewer bits
1:26:56
if all of these numbers have fewer bits of representation it's going to be much
1:26:58
of representation it's going to be much
1:26:58
of representation it's going to be much easier to move them around and that's
1:27:00
easier to move them around and that's
1:27:00
easier to move them around and that's where we start to get into the memory
1:27:02
where we start to get into the memory
1:27:02
where we start to get into the memory bandwidth and the memory of the model so
1:27:04
bandwidth and the memory of the model so
1:27:04
bandwidth and the memory of the model so not only do we have a finite capacity of
1:27:06
not only do we have a finite capacity of
1:27:06
not only do we have a finite capacity of the number of bits that our GPU can
1:27:08
the number of bits that our GPU can
1:27:08
the number of bits that our GPU can store but in addition to that there's a
1:27:11
store but in addition to that there's a
1:27:11
store but in addition to that there's a speed with which you can access this
1:27:13
speed with which you can access this
1:27:13
speed with which you can access this memory um and you have a certain memory
1:27:16
memory um and you have a certain memory
1:27:16
memory um and you have a certain memory bandwidth it's a very precious resource
1:27:19
bandwidth it's a very precious resource
1:27:19
bandwidth it's a very precious resource and in fact many of the deep learning uh
1:27:21
and in fact many of the deep learning uh
1:27:21
and in fact many of the deep learning uh work workloads for training are memory
1:27:23
work workloads for training are memory
1:27:23
work workloads for training are memory bound and what that means is actually
1:27:25
bound and what that means is actually
1:27:25
bound and what that means is actually that the tensor cores that do all these
1:27:27
that the tensor cores that do all these
1:27:27
that the tensor cores that do all these extremely fast multiplications most of
1:27:29
extremely fast multiplications most of
1:27:29
extremely fast multiplications most of the time they're waiting around they're
1:27:31
the time they're waiting around they're
1:27:31
the time they're waiting around they're idle um because we can't feed them with
1:27:34
idle um because we can't feed them with
1:27:34
idle um because we can't feed them with data fast enough we can't load the data
1:27:36
data fast enough we can't load the data
1:27:37
data fast enough we can't load the data fast enough from memory so typical
1:27:38
fast enough from memory so typical
1:27:38
fast enough from memory so typical utilizations of your Hardware if you're
1:27:40
utilizations of your Hardware if you're
1:27:40
utilizations of your Hardware if you're getting 60% uh utilization you're
1:27:43
getting 60% uh utilization you're
1:27:43
getting 60% uh utilization you're actually doing extremely well um so half
1:27:46
actually doing extremely well um so half
1:27:46
actually doing extremely well um so half of the time in a well-tuned application
1:27:48
of the time in a well-tuned application
1:27:48
of the time in a well-tuned application your tensor cores are not doing
1:27:50
your tensor cores are not doing
1:27:50
your tensor cores are not doing multiplies because the data is not
1:27:51
multiplies because the data is not
1:27:51
multiplies because the data is not available so the memory bandwidth here
1:27:53
available so the memory bandwidth here
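One common way to reason about "memory bound" is a roofline-style estimate: an operation can finish no faster than the slower of its compute time and its memory-traffic time. The sketch below is a back-of-the-envelope model, not a measurement. The 19.5 TFLOPS figure is the one from the talk; the roughly 2 TB/s bandwidth is an assumed round number, so check your GPU's datasheet, and the tensor shapes are illustrative.

```python
def roofline_time(flops, bytes_moved, peak_flops, bandwidth):
    """Return (lower-bound runtime in seconds, which resource sets it)."""
    compute_time = flops / peak_flops
    memory_time = bytes_moved / bandwidth
    if memory_time > compute_time:
        return memory_time, "memory-bound"
    return compute_time, "compute-bound"


PEAK_FLOPS = 19.5e12  # fp32 peak quoted in the talk
BANDWIDTH = 2.0e12    # ASSUMED ~2 TB/s; check your GPU's datasheet

n = 16 * 1024 * 768   # elements in one (16, 1024, 768) fp32 activation tensor

# Elementwise add: 1 FLOP per element, reads two tensors and writes one.
add_time, add_bound = roofline_time(n, 3 * 4 * n, PEAK_FLOPS, BANDWIDTH)

# Linear layer 768 -> 768 on the same activations: 2 * 768 FLOPs per output element.
mm_flops = 2 * 768 * n
mm_bytes = 4 * (n + 768 * 768 + n)  # input + weight + output, each touched once
mm_time, mm_bound = roofline_time(mm_flops, mm_bytes, PEAK_FLOPS, BANDWIDTH)

print(f"elementwise add: {add_time * 1e6:8.1f} us  ({add_bound})")
print(f"768x768 linear:  {mm_time * 1e6:8.1f} us  ({mm_bound})")
```

Under these assumptions the cheap elementwise operation is limited by memory traffic, not arithmetic, which is the situation the talk describes.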
1:27:53
So the memory bandwidth here is extremely important as well. And if we come down in the precision for all the floats, all the numbers, weights and activations, suddenly require less memory, so we can store more and we can access it faster. So everything speeds up, and it's amazing. Now let's reap the benefits of it, and let's first look at the TensorFloat-32 format.
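The storage side of this is simple arithmetic. The sketch below uses the approximate 124 million parameter count and standard element sizes, and it counts the weights only; activations, gradients, and optimizer state are left out.

```python
N_PARAMS = 124_000_000  # approximate parameter count of the model in the talk

BYTES_PER_ELEMENT = {"fp64": 8, "fp32": 4, "bf16": 2, "int8": 1}

for dtype, nbytes in BYTES_PER_ELEMENT.items():
    megabytes = N_PARAMS * nbytes / 1e6
    print(f"{dtype}: {megabytes:,.0f} MB for the weights alone")
```

TF32 is deliberately absent from the table: tensors stay 32-bit in memory with TF32, so it changes the speed of the multiply and not the storage.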
1:28:16
Okay, so first of all, what are tensor cores? Well, a tensor core is just an instruction in the A100 architecture, right? What it does is basically a little 4x4 matrix multiply. So this is just matrix multiplication here of 4x4 matrices, and there are multiple configurations as to what precision any of these matrices are in, in what precision the internal accumulate happens, and then what the output precision is, the input precisions, etc. So there's a few switches, but it's basically a 4x4 multiply. And then anytime we have any operations that require matrix multiplication, they get broken up into this instruction of a little 4x4 multiply. Everything gets broken up into this instruction because it's the fastest way to multiply matrices.
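The idea of breaking a large matrix multiply into small fixed-size multiply-accumulates can be sketched in plain Python. This only illustrates the decomposition; it says nothing about how the hardware actually schedules tiles, and the function names are made up for the example.

```python
TILE = 4  # tile edge, matching the 4x4 multiply described in the talk


def matmul_naive(a, b):
    n, k, m = len(a), len(b), len(b[0])
    return [[sum(a[i][p] * b[p][j] for p in range(k)) for j in range(m)]
            for i in range(n)]


def tile_mma(a, b, c, i0, j0, p0):
    """c[i0:i0+4, j0:j0+4] += a[i0:i0+4, p0:p0+4] @ b[p0:p0+4, j0:j0+4]."""
    for i in range(TILE):
        for j in range(TILE):
            acc = c[i0 + i][j0 + j]
            for p in range(TILE):
                acc += a[i0 + i][p0 + p] * b[p0 + p][j0 + j]
            c[i0 + i][j0 + j] = acc


def matmul_tiled(a, b):
    """Multiply matrices whose dimensions are all multiples of TILE."""
    n, k, m = len(a), len(b), len(b[0])
    assert n % TILE == 0 and k % TILE == 0 and m % TILE == 0
    c = [[0] * m for _ in range(n)]
    for i0 in range(0, n, TILE):
        for j0 in range(0, m, TILE):
            for p0 in range(0, k, TILE):
                tile_mma(a, b, c, i0, j0, p0)  # one small multiply-accumulate
    return c


# Small integer matrices (8x8 times 8x12) so the comparison is exact.
a = [[(i * 8 + j) % 7 for j in range(8)] for i in range(8)]
b = [[(i * 3 + j) % 5 for j in range(12)] for i in range(8)]
print(matmul_tiled(a, b) == matmul_naive(a, b))
```

Each call to tile_mma plays the role of one small multiply-accumulate, and the outer loops show how a larger product is assembled from them.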
1:29:06
And it turns out that most of the computational work that we're doing up above, all of it really, is matrix multiplication. Most of the work computationally happens in the linear layers: linear, linear, etc. There's a few things sandwiched in between, so there's some additions in residuals, there's some GELU nonlinearities, there's some layer norms, etc. But if you just time them, you'll see that these are nothing; basically the entire Transformer is just a bunch of matrix multiplications, really. And especially at this small scale, the 124 million parameter model, the biggest matrix multiplication by far is actually the classifier layer at the top. That is a massive matrix multiply going from 768 to 50,257, and that matrix multiply dominates anything else that happens in that network, roughly speaking. So it's matrix multiplies that become a lot faster, which are hidden inside our linear layers, and they're accelerated through tensor cores.
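A rough count of multiply-add work per token makes the size of the classifier layer concrete. The sketch assumes the standard GPT-2 (124M) shapes: 12 blocks, width 768, a 4x MLP expansion, and a 50,257-token vocabulary. It counts only the linear layers and leaves out the attention score and value matmuls.

```python
N_EMBD = 768
VOCAB = 50257
N_LAYER = 12  # GPT-2 (124M) block count

# (in_features, out_features) of each linear layer inside one block
block_linears = {
    "attn qkv": (N_EMBD, 3 * N_EMBD),
    "attn proj": (N_EMBD, N_EMBD),
    "mlp fc": (N_EMBD, 4 * N_EMBD),
    "mlp proj": (4 * N_EMBD, N_EMBD),
}
classifier = (N_EMBD, VOCAB)


def flops_per_token(shape):
    fan_in, fan_out = shape
    return 2 * fan_in * fan_out  # one multiply and one add per weight


largest_block = max(flops_per_token(s) for s in block_linears.values())
head = flops_per_token(classifier)
all_blocks = N_LAYER * sum(flops_per_token(s) for s in block_linears.values())

print(f"largest in-block linear: {largest_block:>12,} FLOPs/token")
print(f"classifier layer:        {head:>12,} FLOPs/token")
print(f"all block linears:       {all_blocks:>12,} FLOPs/token")
print(f"classifier share:        {head / (head + all_blocks):.0%}")
```

By this count the classifier is by far the largest single matrix multiply, although the block linears summed over all 12 layers still account for more of the total.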
1:30:07
Now, the best reference I would say for tensor cores is basically just to go to the A100 architecture white paper. It's pretty detailed, but I think it's relatively readable, mostly, if you half understand what's happening. So, Figure 9: TensorFloat-32. This is the explanation, basically, for TF32 and what happens here. You see that there's many configuration options available here: the input operands and what precisions they are in, the accumulator, and basically the internal representation within the instruction when you do the accumulate of this matrix multiplication. So the intermediate plus-equals of the intermediate little vector multiplies here, that all happens in FP32. And then this is an 8x improvement, as I mentioned, to the ops that we get.
1:31:04
So TF32 specifically, we're looking at this row here, and the way this works is: normally FP32 has 32 bits. TF32 is the exact same bits: we have one sign bit, we have eight exponent bits, except the mantissa bits get cropped. So basically we end up with just 19 bits instead of 32 bits, because the last 13 bits get truncated; they get dropped. And all of this is internal to the instruction, so none of it is visible to anything in our PyTorch. None of our PyTorch code will change, and all of the numbers will look identical. It's just that when you call the tensor core instruction, internally in the hardware it will crop out these 13 bits, and that allows it to calculate this little matrix multiply significantly faster, 8x faster.
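The bit cropping can be imitated in software to see how much precision is lost. This sketch follows the description in the talk (drop the low 13 mantissa bits of a float32); real hardware may round rather than truncate, so treat it as an approximation of the format, not a model of the instruction.

```python
import struct

# Keep sign (1 bit) + exponent (8 bits) + top 10 mantissa bits; zero the low 13.
KEPT_MASK = 0xFFFFE000


def truncate_to_tf32(x):
    """Round x to float32, then zero the 13 least-significant mantissa bits."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return struct.unpack("<f", struct.pack("<I", bits & KEPT_MASK))[0]


print(truncate_to_tf32(1.0 + 2 ** -10))  # smallest kept mantissa step survives
print(truncate_to_tf32(1.0 + 2 ** -11))  # finer than the kept precision, dropped
print(truncate_to_tf32(3.14159265))      # roughly three decimal digits remain
```

The mask has 19 bits set, matching the 1 + 8 + 10 layout described above.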
1:32:02
Now of course this speedup comes at a cost, and the cost is that we are reducing the precision. Our accumulate is still in FP32, our output is FP32, our inputs are FP32, but internally things get truncated in the operands to perform the operation faster, and so our results are starting to be a bit more approximate. But empirically, when you actually train with this, you basically can't tell the difference. So the reason I like TF32 is because if you can tolerate a little bit of a precision fudge, then this is free: none of your code sees this, it's fully internal to the operation, and the operation to you just goes 8x faster and it's a bit more approximate. So it's a pretty sweet spot, I would say, in optimization.
1:32:46
optimization and uh let's see what that looks like first so I've set up our Cod
1:32:48
looks like first so I've set up our Cod
1:32:48
looks like first so I've set up our Cod to just time the uh iterations so import
1:32:51
to just time the uh iterations so import
1:32:51
to just time the uh iterations so import time I changed the hyper parameters so
1:32:54
time I changed the hyper parameters so
1:32:54
time I changed the hyper parameters so that we have something a bit more that
1:32:55
that we have something a bit more that
1:32:55
that we have something a bit more that reflects uh kind of workload that we
1:32:57
reflects uh kind of workload that we
1:32:57
reflects uh kind of workload that we want to run uh because we want to do a
1:32:59
want to run uh because we want to do a
1:32:59
want to run uh because we want to do a fairly large run at the end of this so
1:33:01
Let's use batch size 16, and let's now use the actual GPT-2 maximum sequence length of 1,024 tokens. So this is the configuration. Then, for 50 iterations, I'm just doing something very lazy here: I'm calling time.time() to get the current time, then this is the optimization loop, and now I want to time how long this takes.
1:33:24
Now, one issue with working with GPUs is that when your CPU runs, it's just scheduling work on the GPU. It's ordering some work: it sends a request and then it continues running. So it can happen that we speed through this and queue up a lot of kernels to run on the GPU, and then the CPU gets here and takes the time with time.time(), but the GPU is actually still running, because it takes time to work through the work that was scheduled. You're just building up a queue for the GPU.
1:34:03
So if you need to, you want to call torch.cuda.synchronize(), and this will wait for the GPU to finish all the work that was scheduled to run up above here, and then we can actually take the time. So basically we're waiting for the GPU to finish this iteration, taking the time, and then we're just going to print it.
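The timing pattern described above can be sketched as follows. This is a minimal illustration, not the code from the video: the matrix multiply is a stand-in for the real training step, and the names `timed_step` and `x` are made up for the example. It falls back to the CPU when no GPU is present, in which case no synchronization is needed.

```python
import time
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Stand-in workload (an assumption): a matmul instead of the GPT-2 training step.
x = torch.randn(512, 512, device=device)

def timed_step():
    t0 = time.time()
    y = x @ x  # on a GPU this only queues a kernel and returns right away
    if device == "cuda":
        torch.cuda.synchronize()  # block until all queued GPU work has finished
    t1 = time.time()
    return (t1 - t0) * 1000.0, y  # elapsed time in milliseconds, plus the result

for i in range(3):
    dt_ms, _ = timed_step()
    print(f"step {i}: dt = {dt_ms:.2f} ms")
```

Without the synchronize call, the measured time on a GPU would mostly reflect how long it took the CPU to enqueue the work, not how long the GPU took to do it.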
1:34:24
So here I'm going to run the training loop, and here on the right I'm watching nvidia-smi. We start off at zero, we're not using the GPU, and then by default PyTorch will use GPU 0, so we see that it gets filled up and we're using 35 GB out of the 80 GB available.
1:34:42
Then here on the left we see that, because we've cranked up the batch size, it's now only 20 batches to do a single epoch on our Tiny Shakespeare, and we're seeing roughly 1,000 milliseconds per iteration here.
1:35:00
The first iteration is sometimes slower, and that's because PyTorch might be doing a lot of initialization on the very first iteration. It's probably initializing all these tensors and buffers to hold the gradients; I'm not 100% sure about all the work that happens here, but this could be a slower iteration. When you're timing your logic, you always want to be careful with that. But basically we're seeing 1,000 milliseconds per iteration, so this will run for roughly 50 seconds as we have it right now. That's our baseline in float32.
1:35:32
One more thing I wanted to mention: if this doesn't fit into your GPU and you're getting out-of-memory errors, then start decreasing your batch size until things fit. So instead of 16, try 8 or 4, or whatever you need to fit the batch into your GPU. And if you have a bigger GPU, you can potentially get away with 32 and so on.
1:35:49
By default you basically want to max out the batch size that fits on your GPU, and you want to keep it to nice numbers, so use numbers that have lots of powers of two in them. 16 is a good number; 8, 24, 32, 48 are nice numbers. But don't use something like 17, because that will run very inefficiently on a GPU, and we're going to see that a bit later as well. So for now, let's just stick with 16 and 1,024.
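One rough way to make "lots of powers of two" concrete is to look at the largest power of two that divides a candidate batch size. The helper below is a hypothetical illustration (the name `largest_power_of_two_factor` is made up here), and the transcript itself gives no precise rule, so treat it as a heuristic only.

```python
def largest_power_of_two_factor(n: int) -> int:
    """Return the largest power of two that divides n (n must be positive)."""
    if n <= 0:
        raise ValueError("n must be positive")
    return n & -n  # isolates the lowest set bit of n

for batch_size in (8, 16, 17, 24, 32, 48):
    factor = largest_power_of_two_factor(batch_size)
    print(f"batch size {batch_size:2d}: divisible by {factor}")
```

By this measure 16, 32 and 48 are divisible by at least 16, 8 and 24 by 8, and 17 only by 1, which matches the "nice" versus "ugly" examples given above.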
1:36:22
The one other thing that I added here, and I ran it again, is that I'm calculating a tokens-per-second throughput during training. We might end up changing the batch size around over time, but tokens per second is the objective measure that we actually care about: how many tokens of data are we training on, and what is the throughput of tokens that we're getting in our optimization? Right now we're processing and training on roughly 16,300 tokens per second (16 x 1,024 tokens in about 1,000 milliseconds), and that's a bit more objective metric.
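The throughput figure is just the number of tokens in a batch divided by the time one iteration takes. A small sketch, using the batch size and sequence length from this section; the function name `tokens_per_second` is an assumption for illustration.

```python
def tokens_per_second(batch_size: int, seq_len: int, dt_seconds: float) -> float:
    """Tokens processed per second for one training iteration."""
    return (batch_size * seq_len) / dt_seconds

B, T = 16, 1024  # batch size and sequence length used in this section

# About 1,000 ms per iteration, as measured for the float32 baseline.
print(f"{tokens_per_second(B, T, 1.0):,.0f} tokens/sec")
```

Because the metric is normalized by time, it stays comparable if the batch size is changed later.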
1:36:53
metric okay so let's now enable tf32 now
1:36:53
metric okay so let's now enable tf32 now luckily pytorch makes this fairly easy
1:36:56
luckily pytorch makes this fairly easy
1:36:56
luckily pytorch makes this fairly easy for us and uh to enable tf32 you just
1:36:59
for us and uh to enable tf32 you just
1:36:59
for us and uh to enable tf32 you just need to do a single line and is this and
1:37:02
need to do a single line and is this and
1:37:02
need to do a single line and is this and when we go to the py documentation here
1:37:04
when we go to the py documentation here
1:37:04
when we go to the py documentation here for this function basically this tells
1:37:07
for this function basically this tells
1:37:07
for this function basically this tells pych what kind of kernels to run and by
1:37:10
pych what kind of kernels to run and by
1:37:10
pych what kind of kernels to run and by default I believe it is highest highest
1:37:12
default I believe it is highest highest
1:37:13
default I believe it is highest highest Precision for mat M and that means that
1:37:15
Precision for mat M and that means that
1:37:15
Precision for mat M and that means that everything happens in float 32 just like
1:37:18
everything happens in float 32 just like
1:37:18
everything happens in float 32 just like it did before but if we set it to high
1:37:20
it did before but if we set it to high
1:37:20
it did before but if we set it to high as we do right now Matrix
1:37:22
as we do right now Matrix
1:37:22
as we do right now Matrix multiplications will not use tensor flow
1:37:24
multiplications will not use tensor flow
1:37:24
multiplications will not use tensor flow 32 when it's
1:37:26
32 when it's
1:37:26
32 when it's available my GPU is a100 so it's an
1:37:30
available my GPU is a100 so it's an
1:37:30
available my GPU is a100 so it's an ampere series and therefore tf32 is
1:37:33
ampere series and therefore tf32 is
1:37:33
ampere series and therefore tf32 is available if you have an older GPU this
1:37:35
available if you have an older GPU this
1:37:35
available if you have an older GPU this might not be available for you but for
1:37:38
might not be available for you but for
1:37:38
might not be available for you but for my GPU it's available and so what I
1:37:39
my GPU it's available and so what I
1:37:39
my GPU it's available and so what I expect P to do is that every single
1:37:41
expect P to do is that every single
1:37:41
expect P to do is that every single place where we see an nn. linear inside
1:37:44
place where we see an nn. linear inside
1:37:44
place where we see an nn. linear inside there there's a matrix multiplication
1:37:46
there there's a matrix multiplication
1:37:46
there there's a matrix multiplication and I expect that matrix multiplication
1:37:48
and I expect that matrix multiplication
1:37:48
and I expect that matrix multiplication now to be um running on tensor course
1:37:51
now to be um running on tensor course
1:37:51
now to be um running on tensor course utilizing the TF 32%
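The transcript does not show the line itself. Given the "highest" and "high" options it describes, the call is most likely `torch.set_float32_matmul_precision`, so the sketch below assumes that. The small `nn.Linear` and its sizes are placeholders, not the model from the video.

```python
import torch
import torch.nn as nn

# Assumed to be the "single line" in question. The documented default is
# "highest"; "high" allows float32 matmuls to use TF32 where the hardware
# supports it.
torch.set_float32_matmul_precision("high")
print(torch.get_float32_matmul_precision())

# Nothing else in the code changes: tensors and parameters stay float32.
layer = nn.Linear(768, 3072)
x = torch.randn(16, 768)
y = layer(x)
print(x.dtype, layer.weight.dtype, y.dtype)
```

On hardware without TF32 support the setting has no effect on speed, but the code still runs unchanged.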
1:37:55
So this is the single line of change that is, I believe, necessary, and let's rerun this. We saw that, in terms of the throughput that is promised to us, we're supposed to be getting roughly 8x, so let's see what happens. That 8x came from here, and it also came from looking at this: 156 TFLOPS instead of 19.5.
1:38:24
OK, so what actually happened? We're seeing that our throughput is roughly 3x, not 8x. We're going from 1,000 milliseconds down to 300 milliseconds, and our throughput is now about 50,000 tokens per second. So we have roughly 3x instead of 8x. What happened?
1:38:43
Basically, what's happening here is that, again, a lot of these workloads are memory bound. Even though TF32 offers, in principle, a lot faster throughput, all of these numbers everywhere are still float32s, and it's float32 numbers that are being shipped all over the place through the memory system. It's just costing us way too much time to shuttle around all this data. So even though we've made the multiply itself much faster, we are memory bound and we're not actually seeing the full benefit that would come from this napkin math here.
1:39:19
That said, we are getting a 3x faster throughput, and this is free: a single line of code in PyTorch. All your variables are still float32 everywhere; it just runs faster and it's slightly more approximate, but we're basically not going to notice it. So that's TF32.
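The napkin math can be written out explicitly. The numbers below are the ones quoted in this section (156 versus 19.5 TFLOPS, and roughly 1,000 ms versus 300 ms per iteration); the variable names are only for illustration.

```python
# Peak throughput figures quoted above for the A100.
fp32_tflops = 19.5
tf32_tflops = 156.0
promised_speedup = tf32_tflops / fp32_tflops

# Measured iteration times quoted above (approximate).
dt_fp32_ms = 1000.0
dt_tf32_ms = 300.0
observed_speedup = dt_fp32_ms / dt_tf32_ms

print(f"promised: {promised_speedup:.1f}x, observed: {observed_speedup:.1f}x")
```

The gap between the two ratios is the part of each iteration that is not limited by matmul arithmetic, which the section attributes to moving float32 data through memory.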
1:39:41
tf32 okay so let's now continue so we've
1:39:41
tf32 okay so let's now continue so we've exercised this row and um we saw that we
1:39:44
exercised this row and um we saw that we
1:39:44
exercised this row and um we saw that we can crop out some of the Precision
1:39:46
can crop out some of the Precision
1:39:46
can crop out some of the Precision inside the operation itself but we saw
1:39:49
inside the operation itself but we saw
1:39:49
inside the operation itself but we saw that we're still memory bound we're
1:39:50
that we're still memory bound we're
1:39:50
that we're still memory bound we're still moving around all these floats
1:39:51
still moving around all these floats
1:39:52
still moving around all these floats right otherwise and we're paying that
1:39:53
right otherwise and we're paying that
1:39:53
right otherwise and we're paying that cost because of this so let's now
1:39:56
cost because of this so let's now
1:39:56
cost because of this so let's now decrease the amount of stuff that we're
1:39:57
decrease the amount of stuff that we're
1:39:57
decrease the amount of stuff that we're going to be moving around and we're
1:39:59
going to be moving around and we're
1:39:59
going to be moving around and we're going to do that by dropping down to B
1:40:01
going to do that by dropping down to B
1:40:01
going to do that by dropping down to B float 16 so we're only going to be
1:40:04
float 16 so we're only going to be
1:40:04
float 16 so we're only going to be maintaining 16 bits per float and we're
1:40:07
maintaining 16 bits per float and we're
1:40:07
maintaining 16 bits per float and we're going to use the B flat 16 and I'll
1:40:08
going to use the B flat 16 and I'll
1:40:08
going to use the B flat 16 and I'll explain in a bit uh fp16 difference and
1:40:12
explain in a bit uh fp16 difference and
1:40:12
explain in a bit uh fp16 difference and uh we're going to be in this row so when
1:40:14
uh we're going to be in this row so when
1:40:14
uh we're going to be in this row so when we go back to the documentation here for
1:40:16
we go back to the documentation here for
1:40:17
we go back to the documentation here for the a
1:40:18
the a
1:40:18
the a 100 um we see here the precisions that
1:40:23
100 um we see here the precisions that
1:40:23
100 um we see here the precisions that are are available and this is the
1:40:24
are are available and this is the
1:40:25
are are available and this is the original fp32 the tf32 crops out the
1:40:28
original fp32 the tf32 crops out the
1:40:28
original fp32 the tf32 crops out the Precision and then here in
1:40:30
Precision and then here in
1:40:30
Precision and then here in bf16 you see that it is very similar to
1:40:33
bf16 you see that it is very similar to
1:40:33
bf16 you see that it is very similar to tf32 but it's even more aggressive in
1:40:36
tf32 but it's even more aggressive in
1:40:36
tf32 but it's even more aggressive in cropping off of the Precision the
1:40:38
cropping off of the Precision the
1:40:38
cropping off of the Precision the mantisa of this float so the important
1:40:40
mantisa of this float so the important
1:40:40
mantisa of this float so the important thing with B float 16 is that the
1:40:42
thing with B float 16 is that the
1:40:42
thing with B float 16 is that the exponent bits and the sign bit of course
1:40:45
exponent bits and the sign bit of course
1:40:45
exponent bits and the sign bit of course remain unchanged so if you're familiar
1:40:47
remain unchanged so if you're familiar
1:40:47
remain unchanged so if you're familiar with your float numbers and I think this
1:40:49
with your float numbers and I think this
1:40:49
with your float numbers and I think this should should probably be an entire
1:40:52
should should probably be an entire
1:40:52
should should probably be an entire video by itself
1:40:53
video by itself
1:40:53
video by itself the exponent sets the range that you can
1:40:56
the exponent sets the range that you can
1:40:56
the exponent sets the range that you can represent of your numbers and the
1:40:58
represent of your numbers and the
1:40:58
represent of your numbers and the Precision is how much Precision you have
1:41:00
Precision is how much Precision you have
1:41:00
Precision is how much Precision you have for your numbers and so the range of
1:41:04
for your numbers and so the range of
1:41:04
for your numbers and so the range of numbers is identical but we can we have
1:41:07
numbers is identical but we can we have
1:41:07
numbers is identical but we can we have fewer possibilities within that range
1:41:10
fewer possibilities within that range
1:41:10
fewer possibilities within that range because we are truncating the Mena so we
1:41:12
because we are truncating the Mena so we
1:41:12
because we are truncating the Mena so we have less Precision in that
1:41:14
have less Precision in that
1:41:14
have less Precision in that range what that means is that things are
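To see "same range, less precision" in numbers, the sketch below takes a float32 bit pattern and keeps only its top 16 bits, which is the bfloat16 layout (1 sign bit, 8 exponent bits, 7 mantissa bits). It uses plain truncation for simplicity; real hardware and libraries typically round instead, so this is an illustration, not a reference conversion. The names `FORMATS` and `truncate_to_bf16` are made up for the example.

```python
import struct

# (sign, exponent, mantissa) bit widths of the formats discussed above.
FORMATS = {
    "fp32": (1, 8, 23),
    "tf32": (1, 8, 10),
    "bf16": (1, 8, 7),
    "fp16": (1, 5, 10),
}

def truncate_to_bf16(x: float) -> float:
    """Keep the top 16 bits of x's float32 pattern and zero the rest."""
    bits = struct.unpack(">I", struct.pack(">f", x))[0]
    return struct.unpack(">f", struct.pack(">I", bits & 0xFFFF0000))[0]

for value in (3.14159265, 1e-30, 1e38):
    print(value, "->", truncate_to_bf16(value))
```

Very large and very small values survive with their magnitude intact because the 8 exponent bits are untouched; only the trailing digits are lost.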
1:41:17
range what that means is that things are
1:41:17
range what that means is that things are actually fairly nice because we have the
1:41:19
actually fairly nice because we have the
1:41:19
actually fairly nice because we have the original range of numbers that are
1:41:21
original range of numbers that are
1:41:21
original range of numbers that are representable in float but we just have
1:41:24
representable in float but we just have
1:41:24
representable in float but we just have less Precision for it and the difference
1:41:27
less Precision for it and the difference
1:41:27
less Precision for it and the difference with fp16 is that they actually touch
1:41:29
with fp16 is that they actually touch
1:41:29
with fp16 is that they actually touch and change the range so fp16 cannot
1:41:32
and change the range so fp16 cannot
1:41:32
and change the range so fp16 cannot represent the full range of fp32 it has
1:41:35
represent the full range of fp32 it has
1:41:35
represent the full range of fp32 it has a reduced range and that's where you
1:41:37
a reduced range and that's where you
1:41:37
a reduced range and that's where you start to actually run into issues
1:41:39
start to actually run into issues
1:41:39
start to actually run into issues because now you need uh these gradient
1:41:41
because now you need uh these gradient
1:41:41
because now you need uh these gradient scalers and things like that and I'm not
1:41:43
scalers and things like that and I'm not
1:41:43
scalers and things like that and I'm not going to go into the detail of that in
1:41:45
going to go into the detail of that in
1:41:45
going to go into the detail of that in this video because that's a whole video
1:41:48
this video because that's a whole video
1:41:48
this video because that's a whole video by itself but fb16 actually historically
1:41:50
by itself but fb16 actually historically
1:41:50
by itself but fb16 actually historically came first that was available in the
1:41:52
came first that was available in the
1:41:52
came first that was available in the Volta series before Amper and so fp16
1:41:56
Volta series before Amper and so fp16
1:41:56
Volta series before Amper and so fp16 came first and everyone started to train
1:41:57
came first and everyone started to train
1:41:58
came first and everyone started to train in fp16 but everyone had to use all
1:42:00
in fp16 but everyone had to use all
1:42:00
in fp16 but everyone had to use all these gradient scaling operations which
1:42:02
these gradient scaling operations which
1:42:02
these gradient scaling operations which are kind of annoying and it's an
1:42:03
are kind of annoying and it's an
1:42:03
are kind of annoying and it's an additional source of state and
1:42:05
additional source of state and
1:42:05
additional source of state and complexity and the reason for that was
1:42:07
complexity and the reason for that was
1:42:07
complexity and the reason for that was because the exponent range was reduced
1:42:09
because the exponent range was reduced
1:42:09
because the exponent range was reduced in fp16 so that's the i e fp16 spec and
1:42:13
in fp16 so that's the i e fp16 spec and
1:42:13
in fp16 so that's the i e fp16 spec and then they came out with bf16 and the
1:42:15
then they came out with bf16 and the
1:42:15
then they came out with bf16 and the Ampere and they made it much simpler
1:42:18
Ampere and they made it much simpler
1:42:18
Ampere and they made it much simpler because we're just truncating manessa we
1:42:20
because we're just truncating manessa we
1:42:20
because we're just truncating manessa we have the exact same range and we do not
1:42:21
have the exact same range and we do not
1:42:21
have the exact same range and we do not need gradient scalers so everything is
1:42:24
need gradient scalers so everything is
1:42:24
need gradient scalers so everything is much much simpler now when we do use
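The range difference between FP16 and BF16 can be checked directly with `torch.finfo`, which reports the largest representable value and the machine epsilon for a dtype. The value 70000.0 below is an arbitrary example chosen to sit just above the FP16 maximum.

```python
import torch

for dtype in (torch.float32, torch.bfloat16, torch.float16):
    info = torch.finfo(dtype)
    print(f"{str(dtype):16s} max = {info.max:.3e}   eps = {info.eps:.3e}")

# A value that fits in float32 and bfloat16 but overflows float16.
value = torch.tensor(70000.0)
as_fp16 = value.to(torch.float16)
as_bf16 = value.to(torch.bfloat16)
print("fp16:", as_fp16.item(), " bf16:", as_bf16.item())
```

FP16 has the smaller epsilon (more precision) but a far smaller maximum, which is why training in it needed gradient scaling, while BF16 trades precision for the full float32 range.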
1:42:26
Now, when we do use BF16, though, we are impacting the numbers that we might be seeing in our PyTorch code; this change is not just local to the operation itself. So let's see how that works.
1:42:39
There's some documentation here. I think this is probably the best page to explain how to use mixed precision in PyTorch, because there are many other tutorials, even within the PyTorch documentation, that are a lot more confusing. So I recommend specifically this one, because there are five other copies that I would not recommend. When we come here, ignore everything about gradient scalers and only look at torch.autocast. Basically, this also comes down to a single line of code at the end. This is the context manager that we want, and we want to use it in our network.
1:43:25
When you click into torch.autocast, it has a bit more guidance for you. It's telling you: do not call bfloat16 on any of your tensors, just use autocast, and only surround the forward pass of the model and the loss calculation. Those are the only two things that you should be surrounding; leave the backward pass and the optimizer step alone. That's the guidance that comes from the PyTorch team, so we're going to follow it. For us, because the loss calculation is inside the model's forward pass, we are going to be doing this.
1:43:58
We don't want to be using torch.float16, because if we do that we need to start using gradient scalers as well. So we are going to be using bfloat16. This is only possible on Ampere, but it means that the changes are extremely minimal: basically just this one line of code.
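A minimal sketch of that guidance, assuming a tiny stand-in model (the hypothetical `TinyLM` below, not the GPT-2 module from the video) whose forward pass returns both the logits and the loss. It runs on CPU so it can be tried without a GPU; on a GPU that supports bfloat16 you would pass `device_type="cuda"` instead. Only the forward pass and the loss sit inside `torch.autocast`; `backward()` and the optimizer step stay outside, and no gradient scaler is involved.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)


class TinyLM(nn.Module):
    """Hypothetical stand-in for the GPT model: forward returns (logits, loss)."""

    def __init__(self, vocab_size=64, n_embd=32):
        super().__init__()
        self.wte = nn.Embedding(vocab_size, n_embd)            # token embedding table
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)

    def forward(self, idx, targets=None):
        logits = self.lm_head(self.wte(idx))                   # matmul-like op, autocast candidate
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss


device = "cpu"  # assumption for portability; use "cuda" on a bfloat16-capable GPU
model = TinyLM().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

x = torch.randint(0, 64, (4, 8), device=device)
y = torch.randint(0, 64, (4, 8), device=device)

optimizer.zero_grad()
# Surround only the forward pass and the loss calculation.
with torch.autocast(device_type=device, dtype=torch.bfloat16):
    logits, loss = model(x, y)
# Backward pass and optimizer step are left outside the context manager.
loss.backward()
optimizer.step()

print("logits dtype:", logits.dtype)
print("wte.weight dtype:", model.wte.weight.dtype)
```

Printing the two dtypes is a quick way to check which tensors autocast touched: the activations coming out of the linear layer versus the parameters themselves.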
1:44:19
Let me first break in here before we actually run this. Right after logits, I'd like to show you that, different from the TF32 that we saw, this is actually going to impact our tensors. If we now look at this logits tensor and look at its dtype, we suddenly see that it is now bfloat16; it's not float32 anymore. So our activations have been changed: the activations tensor is now bfloat16. But not everything has changed. model.transformer.wte is the token embedding table; it has a weight inside it, and the dtype of this weight, this parameter, is still torch.float32. So our parameters seem to still be in float32, but our activations, the logits, are now in bfloat16. Clearly this is why we get the mixed precision: some things PyTorch is keeping in float32, some things PyTorch is converting to lower precision, and what gets converted at what point is not super clear.
1:45:30
I remember scrolling down... is it here? Okay, I can't find it. I thought it was here. Okay, there we go. There are a few docs on what gets converted to bfloat16, and when, when you're using autocast. For example, only these matrix-multiply-like operations get converted to bfloat16, but a lot of operations remain in float32. In particular, a lot of normalizations, like layer norms and things like that, might not be converted. So only some layers would selectively be running in bfloat16, but things like softmax, layer norms, log softmax, so loss function calculations, a lot of those things might remain in float32 because they are more susceptible to precision changes. Matrix multiplies are fairly robust to precision changes. So some parts of the network are impacted more or less by the precision change, and basically only some parts of the model are running in reduced precision.
1:46:38
Let's take it for a spin and see what kind of improvement we achieve here. Okay, so we used to be at 333 milliseconds, and we're now at 300. We used to be somewhere around 50,000 tokens per second, and we're now at 55,000. So we're definitely running faster, but maybe not a lot faster, and that's because there are still many, many bottlenecks in our GPT-2. We're just getting started. But we have dropped the precision down as far as we can with my current GPU, which is an A100, and we're using PyTorch autocast. Unfortunately, I don't actually know exactly what PyTorch autocast does. I don't know exactly what's in bfloat16 and what's in float32. We could go in and start to scrutinize it, but these are the kinds of rules that PyTorch has internally, and unfortunately they don't document them very well. So we're not going to go into that in too much detail.
1:47:36
But for now, we are training in bfloat16, we do not need a gradient scaler, and the reason things are running faster is that we are now able to run the tensor cores in bfloat16. That means we are in this row, but we are also paying in precision for this. So we expect slightly less accurate results with respect to the original FP32, but empirically, in many cases, this is a worthwhile tradeoff, because it allows you to run faster, and you could, for example, train longer and make up for that precision decrease. So that's bfloat16 for now.
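To make "paying in precision" concrete, here is a small pure-Python sketch that emulates the two 16-bit formats. It relies on background that is not spelled out in this passage: bfloat16 keeps float32's 8 exponent bits but only 7 mantissa bits, while float16 has 5 exponent bits and 10 mantissa bits. The helper names are hypothetical, and the bfloat16 helper truncates rather than rounds to nearest, which is a simplification.

```python
import struct


def to_bfloat16(x: float) -> float:
    """Emulate bfloat16 by keeping only the top 16 bits of the float32 encoding.

    Simplification: this truncates the mantissa instead of rounding to nearest.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]


def to_float16(x: float) -> float:
    """Round-trip x through IEEE half precision using struct's 'e' format."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


# Precision: with 7 mantissa bits, bfloat16 cannot tell 1.001 apart from 1.0.
near_one = 1.001
bf16_near_one = to_bfloat16(near_one)
fp16_near_one = to_float16(near_one)

# Range: a very small value (think of a tiny gradient) survives in bfloat16
# but underflows to zero in float16, which is the situation gradient scalers
# exist to work around.
tiny = 1e-8
bf16_tiny = to_bfloat16(tiny)
fp16_tiny = to_float16(tiny)

print("1.001 as bfloat16:", bf16_near_one, "| as float16:", fp16_near_one)
print("1e-8  as bfloat16:", bf16_tiny, "| as float16:", fp16_tiny)
```

The two halves of the output show the tradeoff: bfloat16 is coarser near 1.0 than float16, but it keeps very small magnitudes that float16 flushes to zero.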
1:48:17
Okay, so as we can see, we are currently at about 300 milliseconds per iteration, and we're now going to reach for some really heavy weapons in the PyTorch arsenal. In particular, we're going to introduce torch.compile. torch.compile is really quite incredible infrastructure from the PyTorch team, and it's basically a compiler for neural networks. It's almost like GCC for C and C++ code; this is just the GCC of neural nets. It came out a while ago, and it's extremely simple to use. The way to use torch.compile is to do this: it's a single line of code to compile your model and return it. Now, this line of code will cost you compilation time, but as you might guess, it's going to make the code a lot faster. So let's actually run that, because this will take some time to run. Currently, remember, we're at 300 milliseconds, and we'll see what happens.
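The single line in question is `model = torch.compile(model)`. Below is a hedged sketch of one way to wire it behind a flag so it can be switched off while debugging; the `maybe_compile` helper, the `use_compile` flag, and the small `nn.Sequential` model are hypothetical and only stand in for the real training script. The flag defaults to off here so the sketch runs without a compiler toolchain.

```python
import torch
import torch.nn as nn


def maybe_compile(model: nn.Module, enabled: bool) -> nn.Module:
    """Return torch.compile(model) when enabled, otherwise the eager model unchanged."""
    return torch.compile(model) if enabled else model


use_compile = False  # hypothetical flag: set to True for real training runs

# Small stand-in network; the real script would build the GPT model here.
model = nn.Sequential(
    nn.Linear(16, 64),
    nn.GELU(approximate="tanh"),
    nn.Linear(64, 16),
)
model = maybe_compile(model, enabled=use_compile)

# The compiled (or eager) module is called exactly like before.
out = model(torch.randn(4, 16))
print(out.shape)
```

With the flag on, the compilation cost is paid on the first calls rather than at the `torch.compile` line itself, so the first few iterations are the slow ones.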
1:49:08
Now, while this is running, I'd like to explain a little bit of what torch.compile does under the hood. Feel free to read this page of the PyTorch docs, but basically there's no real good reason for you not to use torch.compile in your PyTorch code. I kind of feel like you should be using it almost by default, unless you're debugging, if you want your code to run really fast. There's one line here in the torch.compile docs that I found that actually gets at why this is faster: "speedup mainly comes from reducing Python overhead and GPU read/writes". So let me unpack that a little bit.
1:49:41
Okay, here we are. We went from 300 milliseconds, and we're now running at 129 milliseconds. So 300 divided by 129 is about a 2.3x improvement from a single line of code in PyTorch. So, quite incredible.
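As a quick check on the numbers quoted so far, the sketch below recomputes the speedups and the tokens-per-second figures. The batch shape (16 sequences of 1,024 tokens per step) is an assumption that is not stated in this passage; it is chosen because it is consistent with the roughly 50,000 and 55,000 tokens per second quoted for 333 ms and 300 ms.

```python
# Assumed batch shape (not stated in this passage): 16 sequences of 1024 tokens.
B, T = 16, 1024
tokens_per_step = B * T


def tokens_per_second(ms_per_step: float) -> float:
    """Throughput implied by a per-iteration time in milliseconds."""
    return tokens_per_step / (ms_per_step / 1000.0)


# Per-iteration timings quoted in the transcript, in milliseconds.
step_ms = {
    "before autocast": 333.0,
    "bfloat16 autocast": 300.0,
    "bfloat16 autocast + torch.compile": 129.0,
}

for name, ms in step_ms.items():
    print(f"{name}: {ms:.0f} ms/step, {tokens_per_second(ms):,.0f} tokens/s")

autocast_speedup = step_ms["before autocast"] / step_ms["bfloat16 autocast"]
compile_speedup = step_ms["bfloat16 autocast"] / step_ms["bfloat16 autocast + torch.compile"]
print(f"autocast speedup: {autocast_speedup:.2f}x")
print(f"torch.compile speedup: {compile_speedup:.2f}x")
```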
1:49:56
So what's happening under the hood? When you pass the model to torch.compile, what we have here in this nn.Module is really just the algorithmic description of what we'd like to happen in our network. torch.compile will analyze the entire thing and look at what operations you'd like to use, and with the benefit of knowing exactly what's going to happen, it doesn't have to run in what's called eager mode. It doesn't have to go layer by layer the way the Python interpreter normally would: start at the forward, and the Python interpreter goes, okay, let's do this operation, and then let's do that operation, and it kind of materializes all the operations as it goes through. These calculations are dispatched and run in this order, and the Python interpreter and this code don't know what kind of operations are going to happen later. But torch.compile sees your entire code at the same time, it knows what operations you intend to run, and it will optimize that process. The first thing it will do is take the Python interpreter out of the forward pass entirely and compile the entire neural net as a single object with no Python interpreter involved. So it knows exactly what's going to run, it will just run that, and it's all going to be running in efficient code.
1:51:18
The second thing that happens is this read/write that they mentioned very briefly. A good example of that, I think, is the GELU nonlinearity that we've been looking at. Here we use nn.GELU. Now, this here is me basically just breaking up nn.GELU, which you remember has this formula. So this here is the equivalent implementation of what's happening inside GELU; algorithmically, it's identical.
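A sketch of what such a broken-up GELU might look like, assuming the tanh approximation is the formula being referred to. The class name `TanhGELU` is hypothetical; the comparison against `nn.GELU(approximate="tanh")` is there to check that the hand-written version matches the built-in one numerically.

```python
import math

import torch
import torch.nn as nn


class TanhGELU(nn.Module):
    """Tanh-approximate GELU written as a chain of elementwise operations."""

    def forward(self, input):
        # 0.5 * x * (1 + tanh( sqrt(2/pi) * (x + 0.044715 * x^3) ))
        return 0.5 * input * (
            1.0
            + torch.tanh(
                math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))
            )
        )


torch.manual_seed(0)
x = torch.randn(4, 8)

manual = TanhGELU()(x)
builtin = nn.GELU(approximate="tanh")(x)

max_abs_diff = (manual - builtin).abs().max().item()
print("max abs difference:", max_abs_diff)
```

In eager mode, each arithmetic step in `forward` is its own operation on the whole tensor, which is what makes this a useful example for thinking about memory reads and writes.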
1:51:46
Now, by default, if we were just using this instead of nn.GELU here, what would happen without torch.compile? Well, the Python interpreter would make its way here, and it would go: okay, there's an input; first, let me raise this input to the third power. It's going to dispatch a kernel that takes your input and raises it to the third power, and that kernel will run. When this kernel runs, what ends up happening is that this input is stored in the memory of the GPU.
1:52:13
So here's a helpful example of the layout of what's happening. You have your CPU, which is in every single computer; there are a few cores in there, and you have your RAM, your memory, and the CPU can talk to the memory. This is all well known. But now we've added the GPU, and the GPU is a slightly different architecture. Of course they can communicate, and it's different in that it's got a lot more cores than a CPU; all of those cores are individually a lot simpler, too. But it also has memory, this high-bandwidth memory, HBM (I'm sorry if I'm botching it; I don't even know what that stands for, I'm just realizing that). But this is the memory, and it's very much equivalent to RAM in the computer.
1:53:00
What's happening is that the input is living in the memory, and when you do input cubed, it has to travel to the GPU, to the cores and to all the caches and registers on the actual chip of the GPU. It has to calculate all the elements to the third power, and then it saves the result back to the memory. It's this travel time that actually causes a lot of issues. Remember this memory bandwidth: we can communicate at about 2 terabytes per second, which is a lot, but we also have to traverse this link, and it's very slow. Here on the GPU we're on chip, and everything is super fast within the chip, but going to the memory is extremely expensive and takes an extremely long amount of time. So we load the input, do the calculations, and write the output back, and this round trip takes a lot of time.
1:53:52
takes a lot of time
1:53:53
Now, right after we do that, we multiply by this constant. So what happens then is we dispatch another kernel: the result travels back to the chip, all the elements get multiplied by a constant, and then the results travel back to the memory. Then we take the result and add back input, and so this entire thing again travels to the GPU, adds the input, and gets written back. So we're making all these round trips from the memory to where the compute actually happens, because all the tensor cores and ALUs and everything like that live on the chip in the GPU. We're doing a ton of round trips.
1:54:31
PyTorch, without torch.compile, doesn't know to optimize this, because it doesn't know what kind of operations you're running later. You're just telling it: raise to the third power, then do this, then do that, and it will just do that in that sequence. But torch.compile sees your entire code. It will come here and realize: wait, all of these are elementwise operations, so I'm going to do a single trip of input to the GPU, then for every single element I'm going to do all of these operations while that memory, or rather chunks of it, is on the GPU, and then I'm going to write back a single time. So we're not going to have these round trips. That's one example of what's called kernel fusion, and it's a major way in which everything is sped up. Basically, if you know ahead of time exactly what you're going to compute, you can optimize your round trips to the memory and you're not going to pay the memory bandwidth cost. That's fundamentally what makes some of these operations a lot faster, and it's what they mean by read/writes here.
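The round trips described above can be sketched with a toy model in plain Python. This is only an illustration, not how PyTorch or torch.compile work internally: `GlobalMemory`, `unfused`, and `fused` are hypothetical names, the "memory" is just a counter of element reads and writes, and the constant 0.044715 is assumed here as an example of "this constant".

```python
class GlobalMemory:
    """Toy stand-in for off-chip GPU memory that counts element reads and writes."""

    def __init__(self):
        self.reads = 0
        self.writes = 0

    def load(self, buf):
        # Moving a tensor from memory onto the chip costs one read per element.
        self.reads += len(buf)
        return list(buf)

    def store(self, values):
        # Writing a result back to memory costs one write per element.
        self.writes += len(values)
        return list(values)


def unfused(x, c, mem):
    """Three separate 'kernels' (cube, scale, add), each with its own round trip."""
    t = mem.store([v ** 3 for v in mem.load(x)])       # kernel 1: x^3
    t = mem.store([c * v for v in mem.load(t)])        # kernel 2: c * x^3
    xs, ts = mem.load(x), mem.load(t)                  # kernel 3 reads both operands
    return mem.store([a + b for a, b in zip(xs, ts)])  # x + c * x^3


def fused(x, c, mem):
    """One 'kernel': load once, do all the elementwise math on chip, store once."""
    return mem.store([v + c * v ** 3 for v in mem.load(x)])


x = [0.5, -1.0, 2.0, 3.0]
slow_mem, fast_mem = GlobalMemory(), GlobalMemory()
assert unfused(x, 0.044715, slow_mem) == fused(x, 0.044715, fast_mem)
print("unfused reads/writes:", slow_mem.reads, slow_mem.writes)
print("fused reads/writes:  ", fast_mem.reads, fast_mem.writes)
```

Both versions compute the same values; the only difference in this model is how many elements cross the link to memory, which is the cost kernel fusion is trying to avoid.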
1:55:30
So let me erase this, because we are not using it. And yeah, we should be using torch.compile: our code is now significantly faster, and we're doing about 125,000 tokens per second. But we still have a long way to go.
1:55:47
Before we move on, I wanted to supplement the discussion a little bit with a few more figures, because this is a complicated topic, but it's worth understanding at a high level what's happening here. I could probably spend an entire two-hour video on this, so this is just a preview. This chip here is the GPU, and it is where most of the calculations happen. The chip also has some memory in it, but most of the memory by far is here in the high bandwidth memory (HBM). They're connected, but these are basically two separate chips.
1:56:23
Now, here is a zoom-in of this cartoon diagram of a GPU. Number one, you see this HBM. I realize it's probably very small for you, but on the sides here it says HBM, and those are the links to the HBM. The HBM is, again, off chip. On the chip there are a large number of these streaming multiprocessors. Every one of these is an SM; there are 120 of them in total, and this is where a lot of the calculations happen. And this is a zoom-in of a single SM. It has these four quadrants, and you can see, for example, the tensor core, which is where a lot of the matrix multiply stuff happens. But there are all these other units to do all different kinds of calculations, for FP64, FP32, integers and so on.
1:57:11
So we have all this logic here to do the calculations, but in addition to that there is memory sprinkled throughout the chip. The L2 cache is some amount of memory that lives on the chip, and then on the SMs themselves there's L1 cache. I realize it's probably very small for you, but this blue bar is L1, and there are also registers. So there is memory stored here, but the way this memory is stored is very different from the way memory is stored in HBM. Just in terms of what the silicon looks like, it's a very different implementation: in HBM you would be using transistors and capacitors, and here on the chip it's a very different implementation, with SRAM. But long story short, there is memory inside the chip, but it's not a lot of memory. That's the critical point.
1:58:09
So this is an example diagram of a slightly different GPU, and it shows some typical numbers. For CPU DRAM, which is this thing here, you might have one terabyte, but it would be extremely expensive to access, especially for a GPU, since you have to go through the CPU. Next we have the HBM: we have tens of gigabytes of HBM on a typical GPU, but as I mentioned it's very expensive to access. And then on the chip itself everything is extremely fast, but we only have a couple of tens of megabytes of memory collectively throughout the chip. There's just not enough space, because memory is very expensive on the chip, and so there's not a lot of it, but it is lightning fast to access in relative terms.
1:58:58
So basically, whenever we have these kernels, the more accurate picture of what's happening is that we take these inputs, which live by default in global memory, and now we need to perform some calculation. So we start streaming the data from global memory to the chip, we perform the calculations on the chip, and then we stream the result back and store it in global memory. If we don't have torch.compile, we are streaming the data through the chip, doing the calculations, and saving to the memory, and we're doing those round trips many, many times. But if it's torch compiled, then we start streaming the memory as before, but now while a chunk of the data we're trying to process lives on the chip, it's extremely fast to operate on. So with kernel fusion we can do all the operations right there in an elementwise fashion, and those are very cheap, and then we do a single round trip back to global memory. Operator fusion basically allows you to keep your chunk of data on the chip and do lots of calculations on it before you write it back, and that gives huge savings. That's why torch.compile ends up being a lot faster, or at least that's one of the major reasons. So again, this was just a very brief intro to the memory hierarchy and roughly what torch.compile does for you.
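To get a feel for the savings, here is a back-of-the-envelope estimate of the time spent purely moving data. It is a rough sketch under stated assumptions: the roughly 2 TB/s bandwidth is the figure quoted earlier, while the tensor shape (batch 16, 1024 tokens, 3072-wide activations), the 2-byte element size, and the simplification that every pass reads and writes the full tensor exactly once are all assumptions made for illustration. `transfer_seconds` is a hypothetical helper, not a library function.

```python
def transfer_seconds(num_elements, bytes_per_element, passes, bandwidth_bytes_per_s):
    """Time spent only on memory traffic, assuming each pass reads and
    writes the whole tensor once (a deliberate simplification)."""
    bytes_moved = 2 * passes * num_elements * bytes_per_element
    return bytes_moved / bandwidth_bytes_per_s


HBM_BANDWIDTH = 2e12               # about 2 TB/s, the figure quoted in the talk
num_elements = 16 * 1024 * 3072    # assumed shape: batch * tokens * hidden width
bytes_per_element = 2              # assumed 16-bit activations

# Three separate elementwise kernels versus one fused kernel.
unfused_s = transfer_seconds(num_elements, bytes_per_element, 3, HBM_BANDWIDTH)
fused_s = transfer_seconds(num_elements, bytes_per_element, 1, HBM_BANDWIDTH)
print(f"unfused: {unfused_s * 1e6:.1f} us, fused: {fused_s * 1e6:.1f} us")
```

Under these assumptions the fused version moves a third of the data, so the memory-bound part of the work takes about a third of the time. Real kernels also overlap compute with memory access, so treat this only as an order-of-magnitude picture.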
2:00:19
Now, torch.compile is amazing, but there are operations torch.compile will not find, and an amazing example of that is FlashAttention, to which we turn next. FlashAttention comes from this paper from Stanford in 2022, and it's this incredible algorithm for performing attention and running it a lot faster. So FlashAttention will come here, and we will take out these four lines; FlashAttention implements these four lines really, really quickly. How does it do that? Well, FlashAttention is a kernel fusion operation. In this diagram they're showing PyTorch, and you have these four operations. They're including dropout, but we are not using dropout here, so we just have these four lines of code, and instead of those we are fusing them into a single fused kernel of FlashAttention. So it's a kernel fusion algorithm, but it's a kernel fusion that torch.compile cannot find.
2:01:22
The reason it cannot find it is that it requires an algorithmic rewrite of how attention is actually implemented here. What's remarkable about it is that, if you just count the number of FLOPs, FlashAttention does more FLOPs than this attention here, but FlashAttention is actually significantly faster; in fact they cite 7.6 times faster, potentially. That's because it is very mindful of the memory hierarchy as I described it just now. It's very mindful about what's in high bandwidth memory and what's in shared memory, and it is very careful with how it orchestrates the computation such that we have fewer reads and writes to the high bandwidth memory. So even though we're doing more FLOPs, the expensive part is the load and store into HBM, and that's what they avoid.
2:02:15
In particular, they never materialize this N-by-N attention matrix, this att here. FlashAttention is designed such that this matrix never gets materialized at any point, and it never gets read from or written to the HBM. And this is a very large matrix, right? This is where all the queries and keys interact, and for each head, for each batch element, we're getting a T-by-T matrix of attention, which is a million numbers even for a single head at a single batch index. So basically this is a ton of memory, and it is never materialized.
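The "million numbers" figure is easy to check with a few lines of arithmetic. In this sketch T = 1024 is the sequence length implied by that figure, while the 12 heads, batch size of 16, and 2 bytes per element are assumptions chosen for illustration; `attention_matrix_stats` is a hypothetical helper.

```python
def attention_matrix_stats(T, n_head=1, batch=1, bytes_per_element=2):
    """Number of entries and bytes in the full (batch, n_head, T, T) attention matrix."""
    entries = batch * n_head * T * T
    return entries, entries * bytes_per_element


# A single head at a single batch index.
entries, _ = attention_matrix_stats(T=1024)
print(entries)  # 1048576

# One attention layer across all heads and the whole batch (assumed sizes).
layer_entries, layer_bytes = attention_matrix_stats(T=1024, n_head=12, batch=16)
print(layer_bytes / 2**20, "MiB")  # 384.0 MiB
```

With the unfused implementation, a matrix of roughly this size would be written to and read back from HBM more than once per layer, which is the traffic FlashAttention avoids.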
2:02:54
materialized and the way that this is achieved is that basically the
2:02:57
achieved is that basically the
2:02:57
achieved is that basically the fundamental algorithmic rewrite here
2:02:58
fundamental algorithmic rewrite here
2:02:58
fundamental algorithmic rewrite here relies on this online softmax trick
2:03:02
relies on this online softmax trick
2:03:02
relies on this online softmax trick which was proposed previously and I'll
2:03:03
which was proposed previously and I'll
2:03:03
which was proposed previously and I'll show you the paper in a bit and the
2:03:05
show you the paper in a bit and the
2:03:05
show you the paper in a bit and the online softmax trick coming from a
2:03:07
online softmax trick coming from a
2:03:07
online softmax trick coming from a previous paper um shows how you can
2:03:10
previous paper um shows how you can
2:03:10
previous paper um shows how you can incrementally evaluate a soft Max
2:03:14
incrementally evaluate a soft Max
2:03:14
incrementally evaluate a soft Max without having to sort of realize all of
2:03:16
without having to sort of realize all of
2:03:16
without having to sort of realize all of the inputs to the softmax to do the
2:03:17
the inputs to the softmax to do the
2:03:18
the inputs to the softmax to do the normalization and you do that by having
2:03:19
normalization and you do that by having
2:03:19
normalization and you do that by having these intermediate variables M and L and
2:03:22
these intermediate variables M and L and
2:03:22
these intermediate variables M and L and there's an update to them that allows
2:03:23
there's an update to them that allows
2:03:24
there's an update to them that allows you to evaluate the softmax in an online
2:03:26
you to evaluate the softmax in an online
2:03:26
you to evaluate the softmax in an online manner um now flash attention actually
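As a rough illustration of the m and l update described here, the sketch below streams over the inputs once, keeping a running maximum `m` and a running normalizer `l`, and compares the result against an ordinary softmax. It is plain Python on a list of floats; the function names are made up for this note and are not taken from the paper or from any library.

```python
import math

def softmax_two_pass(xs):
    """Reference softmax: needs every input before it can normalize."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def online_softmax_stats(xs):
    """Visit xs once, maintaining a running max m and running normalizer l."""
    m = float("-inf")  # largest input seen so far
    l = 0.0            # sum of exp(x - m) over the inputs seen so far
    for x in xs:
        m_new = max(m, x)
        # Rescale the old sum to the new maximum, then add the new term.
        l = l * math.exp(m - m_new) + math.exp(x - m_new)
        m = m_new
    return m, l

def softmax_online(xs):
    """Softmax built from the streamed statistics (finite inputs assumed)."""
    m, l = online_softmax_stats(xs)
    return [math.exp(x - m) / l for x in xs]

xs = [1.0, 3.0, -2.0, 0.5]
print(softmax_two_pass(xs))
print(softmax_online(xs))
```

The point of the rescaling step is that `l` stays correct even when a later input raises the maximum, so the normalizer never needs a second pass over the data.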
2:03:30
manner um now flash attention actually
2:03:30
manner um now flash attention actually so recently flash attention 2 came out
2:03:32
so recently flash attention 2 came out
2:03:32
so recently flash attention 2 came out as well so I have that paper up here as
2:03:34
as well so I have that paper up here as
2:03:34
as well so I have that paper up here as well uh that has additional gains to how
2:03:36
well uh that has additional gains to how
2:03:36
well uh that has additional gains to how it calculates flash attention and the
2:03:38
it calculates flash attention and the
2:03:38
it calculates flash attention and the original paper that this is based on
2:03:40
original paper that this is based on
2:03:40
original paper that this is based on basically is this online normalizer
2:03:42
basically is this online normalizer
2:03:42
basically is this online normalizer calculation for softmax and remarkably
2:03:45
calculation for softmax and remarkably
2:03:45
calculation for softmax and remarkably it came out of Nvidia and it came out of
2:03:46
it came out of Nvidia and it came out of
2:03:46
it came out of Nvidia and it came out of it like really early 2018 so this is 4
2:03:50
it like really early 2018 so this is 4
2:03:50
it like really early 2018 so this is 4 years before flash attention
2:03:52
years before flash attention
2:03:52
years before flash attention and this paper says that we propose a
2:03:55
and this paper says that we propose a
2:03:55
and this paper says that we propose a way to compute the classical softmax
2:03:57
way to compute the classical softmax
2:03:57
way to compute the classical softmax with fewer memory accesses and
2:03:59
with fewer memory accesses and
2:03:59
with fewer memory accesses and hypothesize that this reduction in
2:04:00
hypothesize that this reduction in
2:04:00
hypothesize that this reduction in memory accesses should improve softmax
2:04:02
memory accesses should improve softmax
2:04:02
memory accesses should improve softmax performance on actual hardware and so
2:04:05
performance on actual hardware and so
2:04:05
performance on actual hardware and so they are extremely correct in this
2:04:08
they are extremely correct in this
2:04:08
they are extremely correct in this hypothesis but it's really fascinating
2:04:10
hypothesis but it's really fascinating
2:04:10
hypothesis but it's really fascinating to me that they're from Nvidia and that
2:04:11
to me that they're from Nvidia and that
2:04:12
to me that they're from Nvidia and that they had this realization but they
2:04:13
they had this realization but they
2:04:13
they had this realization but they didn't actually take it to the actual
2:04:15
didn't actually take it to the actual
2:04:15
didn't actually take it to the actual flash attention that had to come four
2:04:17
flash attention that had to come four
2:04:18
flash attention that had to come four years later from Stanford so I don't
2:04:20
years later from Stanford so I don't
2:04:20
years later from Stanford so I don't fully understand the historical how this
2:04:22
fully understand the historical how this
2:04:22
fully understand the historical how this happened historically um but they do
2:04:24
But they do basically propose this online update to the softmax right here, and this is fundamentally what the FlashAttention authors reuse to calculate the softmax in a streaming manner. And then they realize they can actually fuse all the other operations with the online softmax calculation into a single fused kernel, FlashAttention, and that's what we are about to use.
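To make the idea of fusing the rest of attention with the online softmax a bit more concrete, here is a small CPU sketch for a single query vector. It walks over the keys and values in blocks, keeps the running `m` and `l`, and rescales a running output accumulator, so the full row of attention scores is never held at once. This is only an illustration under simplifying assumptions (one query, no causal mask, plain Python lists, a made-up `block_size`); it is not the actual CUDA kernel.

```python
import math

def attention_naive(q, keys, values):
    """Materialize every score, softmax them, then take the weighted sum."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / total
            for d in range(dim)]

def attention_streaming(q, keys, values, block_size=2):
    """Visit keys/values block by block; never hold the full score row."""
    scale = 1.0 / math.sqrt(len(q))
    dim = len(values[0])
    m = float("-inf")   # running max of the scores seen so far
    l = 0.0             # running softmax normalizer
    acc = [0.0] * dim   # running un-normalized output
    for start in range(0, len(keys), block_size):
        k_blk = keys[start:start + block_size]
        v_blk = values[start:start + block_size]
        scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in k_blk]
        m_new = max(m, max(scores))
        correction = math.exp(m - m_new)  # rescales everything accumulated so far
        weights = [math.exp(s - m_new) for s in scores]
        l = l * correction + sum(weights)
        acc = [a * correction + sum(w * v[d] for w, v in zip(weights, v_blk))
               for d, a in enumerate(acc)]
        m = m_new
    return [a / l for a in acc]
```

Both functions compute the same output; the streaming one only ever needs one block of scores plus `m`, `l`, and `acc`, which is what lets a real kernel keep the working set in fast on-chip memory.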
2:04:45
So it's a great example, I think, of being aware of the memory hierarchy: the fact that flops don't matter on their own, the entire memory access pattern matters, and that torch.compile is amazing, but there are many optimizations still available to us that torch.compile potentially cannot find. Maybe one day it could, but right now it seems like a lot to ask.
2:05:05
so here's what we're going to do we're
2:05:05
so here's what we're going to do we're going to use Flash attention and the way
2:05:09
going to use Flash attention and the way
2:05:09
going to use Flash attention and the way to do that basically in pytorch is we
2:05:11
to do that basically in pytorch is we
2:05:11
to do that basically in pytorch is we are going to comment out these four
2:05:13
are going to comment out these four
2:05:14
are going to comment out these four lines and we're going to replace them
2:05:15
lines and we're going to replace them
2:05:15
lines and we're going to replace them with a single line and here we are
2:05:18
with a single line and here we are
2:05:18
with a single line and here we are calling this compound operation in
2:05:20
calling this compound operation in
2:05:20
calling this compound operation in pytorch called scale that product
2:05:22
pytorch called scale that product
2:05:22
pytorch called scale that product attention and uh pytorch will call flash
2:05:27
attention and uh pytorch will call flash
2:05:27
attention and uh pytorch will call flash attention when you use it in this way
2:05:30
attention when you use it in this way
2:05:31
attention when you use it in this way I'm not actually 100% sure why torch
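For reference, here is a small standalone sketch of that swap, assuming PyTorch 2.0 or later. The tensor sizes are tiny illustrative values, and the four-line version is a typical causal attention written out by hand rather than a copy of the code on screen. Which kernel `F.scaled_dot_product_attention` actually dispatches to depends on the device, dtype, and PyTorch build, so this only checks that the two paths agree numerically.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, n_head, T, head_dim = 2, 4, 16, 8   # small illustrative sizes
q = torch.randn(B, n_head, T, head_dim)
k = torch.randn(B, n_head, T, head_dim)
v = torch.randn(B, n_head, T, head_dim)

# Hand-written causal attention: materializes the full (T, T) matrix per head.
mask = torch.tril(torch.ones(T, T)).view(1, 1, T, T)
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
att = att.masked_fill(mask == 0, float("-inf"))
att = F.softmax(att, dim=-1)
y_manual = att @ v

# Single-call replacement: PyTorch may use a fused kernel when one is available.
y_fused = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(torch.allclose(y_manual, y_fused, atol=1e-4))
```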
2:05:32
I'm not actually 100% sure why torch
2:05:32
I'm not actually 100% sure why torch compile doesn't realize that these four
2:05:34
compile doesn't realize that these four
2:05:34
compile doesn't realize that these four lines should just call flash attention
2:05:36
lines should just call flash attention
2:05:36
lines should just call flash attention in this exact way we have to do it again
2:05:38
in this exact way we have to do it again
2:05:38
in this exact way we have to do it again for it which in my opinion is a little
2:05:40
for it which in my opinion is a little
2:05:40
for it which in my opinion is a little bit odd but um here we are so you have
2:05:46
bit odd but um here we are so you have
2:05:46
bit odd but um here we are so you have to use this compound up and uh let's
2:05:48
to use this compound up and uh let's
2:05:49
to use this compound up and uh let's wait for a few moments before torch comp
2:05:51
wait for a few moments before torch comp
2:05:51
wait for a few moments before torch comp compile gets around to it and then let's
2:05:53
compile gets around to it and then let's
2:05:53
compile gets around to it and then let's remember that we achieved 6.05 661 I
2:05:58
remember that we achieved 6.05 661 I
2:05:58
remember that we achieved 6.05 661 I have it here that's the loss we were
2:06:00
have it here that's the loss we were
2:06:00
have it here that's the loss we were expecting to see and we took 130
2:06:03
expecting to see and we took 130
2:06:03
expecting to see and we took 130 milliseconds uh before this change so
2:06:05
milliseconds uh before this change so
2:06:05
milliseconds uh before this change so we're expecting to see the exact same
2:06:07
we're expecting to see the exact same
2:06:07
we're expecting to see the exact same result by iteration 49 but we expect to
2:06:10
result by iteration 49 but we expect to
2:06:10
result by iteration 49 but we expect to see faster runtime because Flash
2:06:13
see faster runtime because Flash
2:06:13
see faster runtime because Flash attention is just a an algorithmic
2:06:14
attention is just a an algorithmic
2:06:14
attention is just a an algorithmic rewrite and it's a faster kernel but it
2:06:16
rewrite and it's a faster kernel but it
2:06:16
rewrite and it's a faster kernel but it doesn't actually change any of the
2:06:17
doesn't actually change any of the
2:06:17
doesn't actually change any of the computation and we should have the exact
2:06:19
computation and we should have the exact
2:06:19
computation and we should have the exact same optimization so okay so we're a lot
2:06:21
Okay, so we're a lot faster: we're at about 95 milliseconds, and we achieved 6.058. So they're basically identical up to a floating-point fudge factor. It's the identical computation, but it's significantly faster, going from 130 to roughly 96. This is 96 divided by 130-ish, so maybe a 26%-ish improvement. Really interesting, and that is FlashAttention.
2:06:54
Okay, we are now getting to one of my favorite optimizations. It is simultaneously the dumbest and the most brilliant optimization, and it's always a little bit surprising to me. Anyway, I mentioned a few minutes ago that there are some numbers that are nice and some numbers that are ugly. 64 is a beautiful, nice number, 128 is even nicer, 256 is beautiful. What makes these numbers beautiful is that there are many powers of two inside them: you can divide by two many times. Examples of ugly numbers are 13 and 17 and so on: prime numbers, numbers that are not even.
2:07:32
So pretty much you always want to use nice numbers in all of your code that deals with neural networks or CUDA, because everything in CUDA works in powers of two, and lots of kernels are written in terms of powers of two. There are lots of blocks of sizes 16 and 64 and so on, so everything is written in those terms, and you always have special-case handling for all kinds of logic when your inputs are not made of nice numbers.
2:08:01
So let's see what that looks like. Basically, scan your code and look for ugly numbers; that's roughly the heuristic. So three times is kind of ugly. I'm not 100% sure, maybe this can be improved, but this is ugly and not ideal. Four times is nice. 1024 is very nice, that's a power of two. 12 is a little bit suspicious, not too many powers of two. 768 is great. 50,257 is a really, really ugly number: first of all it's odd, so there are no powers of two in there at all. So this is a very ugly number and it's highly suspicious.
2:08:43
Then when we scroll down, all these numbers are nice, and then here we have mostly nice numbers except for 25. In this configuration of GPT-2 XL, the number of heads is 25. That's a really ugly number, that's an odd number, and actually this did cause a lot of headaches for us recently when we were trying to optimize some kernels to run this fast; it required a bunch of special-case handling. So basically we have some ugly numbers, and some of them are easier to fix than others. In particular the vocab size being 50257: that's a very ugly number, very suspicious, and we want to fix it.
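The "how many powers of two are inside it" check can be made mechanical. The helper below is a hypothetical utility written for this note: it counts how many times a number divides evenly by two, and it is applied to the numbers mentioned in the scan above.

```python
def factors_of_two(n):
    """How many times n can be divided by 2 before it becomes odd."""
    count = 0
    while n > 0 and n % 2 == 0:
        n //= 2
        count += 1
    return count

# Numbers mentioned in the transcript, nice and ugly alike.
for n in (64, 128, 256, 1024, 768, 12, 4, 3, 13, 17, 25, 50257):
    print(f"{n:>6}: divisible by two {factors_of_two(n)} time(s)")
```

By this measure 768 has eight factors of two, 12 has only two, and 25 and 50257 have none at all.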
2:09:23
now when you when you fix these things
2:09:23
now when you when you fix these things uh one of the easy ways to do that is
2:09:24
uh one of the easy ways to do that is
2:09:24
uh one of the easy ways to do that is you basically um increase the number
2:09:27
you basically um increase the number
2:09:27
you basically um increase the number until it's the nearest power of two that
2:09:29
until it's the nearest power of two that
2:09:29
until it's the nearest power of two that you like so here's a much nicer number
2:09:32
you like so here's a much nicer number
2:09:32
you like so here's a much nicer number it's
2:09:33
it's
2:09:33
it's 50304 and why is that because 50304 can
2:09:37
50304 and why is that because 50304 can
2:09:37
50304 and why is that because 50304 can be divided by 8 or by 16 or by 32
2:09:43
be divided by 8 or by 16 or by 32
2:09:43
be divided by 8 or by 16 or by 32 64 it can even be divided by 128 I think
2:09:46
64 it can even be divided by 128 I think
2:09:46
64 it can even be divided by 128 I think yeah so it's a very nice number um so
2:09:49
yeah so it's a very nice number um so
2:09:49
yeah so it's a very nice number um so what we're going to do here is the GPT
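One way to arrive at 50304 is to round 50257 up to the next multiple of 128. The helper name below is made up for illustration; the divisibility check confirms the claim about 8 through 128 and shows where it stops.

```python
def round_up_to_multiple(n, multiple):
    """Smallest multiple of `multiple` that is greater than or equal to n."""
    return ((n + multiple - 1) // multiple) * multiple

padded = round_up_to_multiple(50257, 128)
print(padded)

# Which of these powers of two divide the padded size evenly?
divisors = [d for d in (8, 16, 32, 64, 128, 256) if padded % d == 0]
print(divisors)
```

Since 50304 is 128 times 393 and 393 is odd, 128 is the largest power of two that divides it.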
2:09:51
what we're going to do here is the GPT
2:09:51
what we're going to do here is the GPT config and you see that we initialized B
2:09:53
config and you see that we initialized B
2:09:53
config and you see that we initialized B cap size to
2:09:54
cap size to
2:09:54
cap size to 50257 Let's override just
2:09:58
50257 Let's override just
2:09:58
50257 Let's override just that um element to be
2:10:01
50304 okay so everything else stays the
2:10:05
50304 okay so everything else stays the
2:10:05
50304 okay so everything else stays the same we're just increasing our
2:10:06
same we're just increasing our
2:10:06
same we're just increasing our vocabulary size so we're adding it's
2:10:09
vocabulary size so we're adding it's
2:10:09
vocabulary size so we're adding it's almost like we're adding fake tokens uh
2:10:12
almost like we're adding fake tokens uh
2:10:12
almost like we're adding fake tokens uh so that book up size has powers of two
2:10:14
so that book up size has powers of two
2:10:14
so that book up size has powers of two inside it now actually what I'm doing
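A minimal sketch of that override, assuming the config is a dataclass along these lines. The default values match the numbers read out during the scan; the exact field names are an assumption here and may differ from the code on screen.

```python
from dataclasses import dataclass

@dataclass
class GPTConfig:
    block_size: int = 1024   # maximum sequence length
    vocab_size: int = 50257  # size of the GPT-2 tokenizer vocabulary
    n_layer: int = 12        # number of transformer blocks
    n_head: int = 12         # attention heads per block
    n_embd: int = 768        # embedding dimension

# Override only the vocabulary size; every other field keeps its default.
config = GPTConfig(vocab_size=50304)
extra_rows = config.vocab_size - GPTConfig().vocab_size
print(config, extra_rows)
```

The difference is 47 rows, which are the "fake tokens" that the tokenizer will never produce.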
2:10:16
inside it now actually what I'm doing
2:10:16
inside it now actually what I'm doing here by the way is I'm increasing the
2:10:18
here by the way is I'm increasing the
2:10:18
here by the way is I'm increasing the amount of computation that our network
2:10:19
amount of computation that our network
2:10:19
amount of computation that our network will be doing if you just count the the
2:10:21
will be doing if you just count the the
2:10:21
will be doing if you just count the the flops on like do the math of how many
2:10:23
flops on like do the math of how many
2:10:23
flops on like do the math of how many flops we're doing we're going to be
2:10:25
flops we're doing we're going to be
2:10:25
flops we're doing we're going to be doing more flops and we still have to
2:10:27
doing more flops and we still have to
2:10:27
doing more flops and we still have to think through whether this doesn't break
2:10:30
think through whether this doesn't break
2:10:30
think through whether this doesn't break anything but if I just run this uh let's
2:10:33
anything but if I just run this uh let's
2:10:33
anything but if I just run this uh let's see what we get uh currently this ran in
2:10:35
see what we get uh currently this ran in
2:10:35
see what we get uh currently this ran in maybe
2:10:38
maybe
2:10:38
maybe 96.5 milliseconds per step I'm just kind
2:10:41
96.5 milliseconds per step I'm just kind
2:10:41
96.5 milliseconds per step I'm just kind of like eyeballing it and let's see what
2:10:43
of like eyeballing it and let's see what
2:10:43
of like eyeballing it and let's see what kind of a result we're going to
2:10:46
kind of a result we're going to
2:10:46
kind of a result we're going to get uh while this is compiling let's
2:10:49
get uh while this is compiling let's
2:10:49
get uh while this is compiling let's think through whether our code actually
2:10:51
think through whether our code actually
2:10:51
think through whether our code actually works okay when we increase the vocap
2:10:53
works okay when we increase the vocap
2:10:53
works okay when we increase the vocap size like this let's look at where vocap
2:10:55
size like this let's look at where vocap
2:10:55
size like this let's look at where vocap size is actually
2:10:57
size is actually
2:10:57
size is actually used so we swing up to the inet and we
2:11:00
used so we swing up to the inet and we
2:11:00
used so we swing up to the inet and we see that it's used inside the embedding
2:11:01
see that it's used inside the embedding
2:11:01
see that it's used inside the embedding table of course so all the way at the
2:11:03
table of course so all the way at the
2:11:03
table of course so all the way at the bottom of the Transformer and it's used
2:11:05
bottom of the Transformer and it's used
2:11:05
bottom of the Transformer and it's used at the classifier layer all the way at
2:11:06
at the classifier layer all the way at
2:11:06
at the classifier layer all the way at the top of the Transformer so in two
2:11:08
the top of the Transformer so in two
2:11:08
the top of the Transformer so in two places and let's take a look and we're
2:11:11
places and let's take a look and we're
2:11:11
places and let's take a look and we're running at 93 so 93 milliseconds instead
2:11:14
running at 93 so 93 milliseconds instead
2:11:14
running at 93 so 93 milliseconds instead of
2:11:15
of
2:11:15
of 96.5 so we are seeing a roughly yeah 4%
2:11:19
96.5 so we are seeing a roughly yeah 4%
2:11:19
96.5 so we are seeing a roughly yeah 4% Improvement here uh by doing more
2:11:22
Improvement here uh by doing more
2:11:22
Improvement here uh by doing more calculations and the reason for this is
2:11:25
calculations and the reason for this is
2:11:25
calculations and the reason for this is we fixed we've made an ugly number into
2:11:28
we fixed we've made an ugly number into
2:11:28
we fixed we've made an ugly number into a nice number let's I'm going to come
2:11:30
a nice number let's I'm going to come
2:11:30
a nice number let's I'm going to come into the explanation for that a little
2:11:32
into the explanation for that a little
2:11:32
into the explanation for that a little bit again but for now let's just
2:11:34
bit again but for now let's just
2:11:34
bit again but for now let's just convince ourselves that we're not
2:11:35
convince ourselves that we're not
2:11:35
convince ourselves that we're not breaking anything when we do this so
2:11:36
breaking anything when we do this so
2:11:36
breaking anything when we do this so first of all we've made the the wte the
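The percentages quoted for the two changes can be checked with a line of arithmetic each. The timings below are the eyeballed per-step numbers from the run (130 ms to about 96 ms for FlashAttention, then 96.5 ms to 93 ms for the padded vocabulary), so the results are only as precise as those readings.

```python
def percent_less_time(before_ms, after_ms):
    """Reduction in step time, as a percentage of the original time."""
    return 100.0 * (before_ms - after_ms) / before_ms

flash_gain = percent_less_time(130.0, 96.0)   # FlashAttention swap
padding_gain = percent_less_time(96.5, 93.0)  # vocab size 50257 -> 50304
print(f"FlashAttention: {flash_gain:.1f}% less time per step")
print(f"Padded vocab:   {padding_gain:.1f}% less time per step")
```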
2:11:39
first of all we've made the the wte the
2:11:39
first of all we've made the the wte the embedding table for the tokens we've
2:11:41
embedding table for the tokens we've
2:11:41
embedding table for the tokens we've made it larger it's almost like we
2:11:43
made it larger it's almost like we
2:11:43
made it larger it's almost like we introduced more tokens at the bottom and
2:11:46
introduced more tokens at the bottom and
2:11:46
introduced more tokens at the bottom and these tokens are never used because the
2:11:48
these tokens are never used because the
2:11:48
these tokens are never used because the gbt tokenizer only has tokens up to
2:11:50
gbt tokenizer only has tokens up to
2:11:50
gbt tokenizer only has tokens up to $50,000
2:11:51
$50,000
2:11:51
$50,000 256 and so we'll never index into the
2:11:55
256 and so we'll never index into the
2:11:55
256 and so we'll never index into the rows that we've added so we're wasting a
2:11:57
rows that we've added so we're wasting a
2:11:57
rows that we've added so we're wasting a little bit of space here by creating
2:11:59
little bit of space here by creating
2:11:59
little bit of space here by creating memory that's never going to be accessed
2:12:01
memory that's never going to be accessed
2:12:01
memory that's never going to be accessed never going to be used Etc now that's
2:12:03
never going to be used Etc now that's
2:12:03
never going to be used Etc now that's not fully correct because this wte
2:12:06
not fully correct because this wte
2:12:06
not fully correct because this wte weight ends up being shared and ends up
2:12:08
weight ends up being shared and ends up
2:12:08
weight ends up being shared and ends up being used in the classifier here at the
2:12:10
being used in the classifier here at the
2:12:10
being used in the classifier here at the end so what is that doing to the
2:12:12
end so what is that doing to the
2:12:13
end so what is that doing to the classifier right here well what what
2:12:15
classifier right here well what what
2:12:15
classifier right here well what what that's doing is we're predicting
2:12:16
that's doing is we're predicting
2:12:16
that's doing is we're predicting additional Dimensions at the classifier
2:12:18
additional Dimensions at the classifier
2:12:18
additional Dimensions at the classifier now and we're predicting probabilities
2:12:20
now and we're predicting probabilities
2:12:20
now and we're predicting probabilities for tokens that will of course never be
2:12:21
for tokens that will of course never be
2:12:21
for tokens that will of course never be present in the training set um and so
2:12:25
present in the training set um and so
2:12:25
present in the training set um and so therefore the network has to learn that
2:12:27
therefore the network has to learn that
2:12:27
therefore the network has to learn that these probabilities uh have to be driven
2:12:29
these probabilities uh have to be driven
2:12:29
these probabilities uh have to be driven to zero and so the logits that the
2:12:31
to zero and so the logits that the
2:12:31
to zero and so the logits that the network produces have to drive those
2:12:33
network produces have to drive those
2:12:33
network produces have to drive those dimensions of the output to negative
2:12:35
dimensions of the output to negative
2:12:35
dimensions of the output to negative Infinity but it that's no different from
2:12:38
Infinity but it that's no different from
2:12:38
Infinity but it that's no different from all the other tokens that are already in
2:12:39
all the other tokens that are already in
2:12:39
all the other tokens that are already in our data set um or rather that are not
2:12:42
our data set um or rather that are not
2:12:42
our data set um or rather that are not in our data set so Shakespeare only
2:12:45
in our data set so Shakespeare only
2:12:45
in our data set so Shakespeare only probably uses let's say a th000 tokens
2:12:46
probably uses let's say a th000 tokens
2:12:46
probably uses let's say a th000 tokens out of 50,000 to 57 tokens so most of
2:12:49
out of 50,000 to 57 tokens so most of
2:12:49
out of 50,000 to 57 tokens so most of the tokens are already being driven to
2:12:51
the tokens are already being driven to
2:12:51
the tokens are already being driven to zero probability by the optimization we'
2:12:53
zero probability by the optimization we'
2:12:53
zero probability by the optimization we' just introduced a few more tokens now
2:12:55
just introduced a few more tokens now
2:12:55
just introduced a few more tokens now that in a similar manner will never be
2:12:57
that in a similar manner will never be
2:12:57
that in a similar manner will never be used and have to be driven to zero in
2:12:59
used and have to be driven to zero in
2:12:59
used and have to be driven to zero in probability um so functionally though
2:13:02
probability um so functionally though
2:13:02
probability um so functionally though nothing breaks we're using a bit more
2:13:05
nothing breaks we're using a bit more
2:13:05
nothing breaks we're using a bit more extra um memory but otherwise this is a
2:13:08
extra um memory but otherwise this is a
2:13:08
extra um memory but otherwise this is a harmless operation as far as I can tell
2:13:11
harmless operation as far as I can tell
2:13:11
harmless operation as far as I can tell but and we're adding calculation but
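A toy illustration of why the extra output slots are harmless once their logits are very negative: the loss on a real token barely changes, and the padded slots get essentially zero probability. The three-token "vocabulary", the two padded slots, and the value of -30 are all made-up numbers for this sketch; in the real model those logits are learned during training.

```python
import math

def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(logits, target):
    """Negative log-probability of the target index under softmax(logits)."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum - logits[target]

real_logits = [2.0, 0.5, -1.0]                # logits for the real tokens
padded_logits = real_logits + [-30.0, -30.0]  # two never-used padded slots

loss_real = cross_entropy(real_logits, 0)
loss_padded = cross_entropy(padded_logits, 0)
pad_probs = softmax(padded_logits)[3:]
print(loss_real, loss_padded, pad_probs)
```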
2:13:12
but and we're adding calculation but
2:13:12
but and we're adding calculation but it's running faster and it's running
2:13:14
it's running faster and it's running
2:13:14
it's running faster and it's running faster because as I mentioned in Cuda so
2:13:17
faster because as I mentioned in Cuda so
2:13:17
faster because as I mentioned in Cuda so many kernels use uh block tiles and
2:13:21
many kernels use uh block tiles and
2:13:21
many kernels use uh block tiles and these block towels are usually nice
2:13:22
these block towels are usually nice
2:13:22
these block towels are usually nice numbers uh so powers of two so
2:13:25
numbers uh so powers of two so
2:13:25
numbers uh so powers of two so calculations are done in like chunks of
2:13:26
calculations are done in like chunks of
2:13:26
calculations are done in like chunks of 64 or chunks of 32 and when your um when
2:13:31
64 or chunks of 32 and when your um when
2:13:31
64 or chunks of 32 and when your um when your desired calculation doesn't neatly
2:13:32
your desired calculation doesn't neatly
2:13:32
your desired calculation doesn't neatly fit into those block tiles um there are
2:13:36
fit into those block tiles um there are
2:13:36
fit into those block tiles um there are all kinds of boundary kernels that can
2:13:38
all kinds of boundary kernels that can
2:13:38
all kinds of boundary kernels that can kick in to like do the last part so
2:13:42
kick in to like do the last part so
2:13:42
kick in to like do the last part so basically in a lot of kernels they will
2:13:44
basically in a lot of kernels they will
2:13:44
basically in a lot of kernels they will chunk at up your input and they will do
2:13:46
chunk at up your input and they will do
2:13:46
chunk at up your input and they will do the nice part first and then they have a
2:13:47
the nice part first and then they have a
2:13:47
the nice part first and then they have a whole second second phase where they
2:13:50
whole second second phase where they
2:13:50
whole second second phase where they come back to any that like uh remains uh
2:13:53
come back to any that like uh remains uh
2:13:54
come back to any that like uh remains uh and then they process the remaining part
2:13:56
and then they process the remaining part
2:13:56
and then they process the remaining part and the kernels for that could be very
2:13:57
and the kernels for that could be very
2:13:57
and the kernels for that could be very inefficient and so you're basically um
2:14:00
inefficient and so you're basically um
2:14:00
inefficient and so you're basically um spinning up all this extra compute and
2:14:02
spinning up all this extra compute and
2:14:02
spinning up all this extra compute and is extremely inefficient so you might as
2:14:04
is extremely inefficient so you might as
2:14:04
is extremely inefficient so you might as well pad your inputs and um make it fit
2:14:07
well pad your inputs and um make it fit
2:14:07
well pad your inputs and um make it fit nicely and usually that empiric lens up
2:14:10
nicely and usually that empiric lens up
2:14:10
nicely and usually that empiric lens up actually running faster um so this is
2:14:13
actually running faster um so this is
2:14:13
actually running faster um so this is another example of a 4% Improvement that
2:14:16
another example of a 4% Improvement that
2:14:16
another example of a 4% Improvement that we've added and this is something that
2:14:18
we've added and this is something that
2:14:18
we've added and this is something that also torch compile did not find for us
2:14:21
also torch compile did not find for us
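As a rough illustration of the padding trick described here, the snippet below rounds a size up to the next multiple of a block size. The helper name pad_to_multiple and the choice of 128 as the multiple are assumptions made for this sketch, not taken from the code in the video. The only figures from the talk are 50,257 and 50,304.

```python
def pad_to_multiple(n: int, multiple: int) -> int:
    """Round n up to the nearest multiple of `multiple` (hypothetical helper)."""
    return ((n + multiple - 1) // multiple) * multiple

vocab_size = 50257                          # GPT-2 vocabulary size from the talk
padded = pad_to_multiple(vocab_size, 128)   # 128 is an assumed "nice" block size
extra = padded - vocab_size                 # tokens that exist but are never used

print(padded, extra)  # 50304 47
```

The extra rows only cost a little memory. The model still has to learn to push their probability toward zero, just as it does for any real token that never shows up in the data.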
2:14:21
You would hope that torch.compile at some point could figure an optimization like this out, but for now this is it. I also have to point out that we're using PyTorch nightly, so that's why we're only seeing 4%. If you're using PyTorch 2.3.1 or earlier, you would actually see something like a 30% improvement just from this change, from changing it from 50,257 to 50,304. So again, this is one of my favorite examples of having to understand what's under the hood and how it all works, and of knowing what kinds of things to tinker with to push the performance of your code.
2:14:56
Okay, so at this point we have improved the performance by about 11x, right? Because we started at about 1,000 milliseconds per step and we're now down to like 93 milliseconds. So that's quite good, and we're doing a much better job of utilizing our GPU resources. I'm now going to turn to more algorithmic changes and improvements to the actual optimization itself. What we would like to do is follow the hyperparameters that are mentioned in the GPT-2 or GPT-3 paper. Now sadly, the GPT-2 paper doesn't actually say too much. It's very nice of them that they released the model weights and the code, but the paper itself is extremely vague as to the optimization details. The code that they released, the code we've been looking at, is just the inference code, so there's no training code here and very few hyperparameters. So this doesn't tell us too much either.
2:15:46
For that we have to turn to the GPT-3 paper. In the appendix of the GPT-3 paper they have a lot more hyperparameters for us to use, and the GPT-3 paper in general is a lot more detailed as to all of the small details that go into the model training. But the GPT-3 models were never released. So for GPT-2 we have the weights but no details, and for GPT-3 we have lots of details but no weights. Roughly speaking, though, the GPT-2 and GPT-3 architectures are very, very similar, and there are very few changes. The context length was expanded from 1024 to 2048, and that's kind of the major change. Some of the hyperparameters around the Transformer have changed, but otherwise they're pretty much the same model. It's just that GPT-3 was trained for a lot longer on a bigger data set and has a lot more thorough evaluations, and the GPT-3 model is 175 billion parameters instead of 1.6 billion in GPT-2. So, long story short, we're going to go to the GPT-3 paper to follow along with some of the hyperparameters.
2:16:54
So: "to train all the versions of GPT-3, we use Adam with beta 1, beta 2 of 0.9 and 0.95." Let's swing over here and make sure that the betas parameter, which you can see here defaults to 0.9 and 0.999, is actually set to 0.9 and 0.95. Then the epsilon parameter: you can see the default is 1e-8, and this is also 1e-8. Let's just put it in so that we're explicit.
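To make the roles of the two betas and epsilon concrete, here is a minimal plain-Python sketch of a single Adam update on one scalar parameter. It is not the optimizer code from the video. The function name adam_step and the toy values are made up for illustration, and weight decay is left out. In PyTorch the same three numbers would be passed as the betas and eps arguments of an optimizer such as torch.optim.AdamW.

```python
import math

def adam_step(param, grad, m, v, t, lr=3e-4, beta1=0.9, beta2=0.95, eps=1e-8):
    """One Adam update for a single scalar parameter (t counts steps from 1)."""
    m = beta1 * m + (1 - beta1) * grad          # running average of the gradient
    v = beta2 * v + (1 - beta2) * grad * grad   # running average of the squared gradient
    m_hat = m / (1 - beta1 ** t)                # bias correction for the zero init
    v_hat = v / (1 - beta2 ** t)
    param = param - lr * m_hat / (math.sqrt(v_hat) + eps)  # eps guards the division
    return param, m, v

# Toy usage: one step from param = 1.0 with a gradient of 2.0.
param, m, v = 1.0, 0.0, 0.0
param, m, v = adam_step(param, grad=2.0, m=m, v=v, t=1)
print(param, m, v)
```

A beta2 of 0.95 instead of the default 0.999 makes the squared-gradient average forget old gradients faster, so the step size reacts more quickly when the gradient scale changes.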
2:17:22
Now next up, they say "we clip the global norm of the gradient at 1.0." What this is referring to is that once we calculate the gradients, right after loss.backward(), we basically have the gradients at all the parameter tensors, and what people like to do is clip them to have some kind of a maximum norm. In PyTorch this is fairly easy to do. It's one line of code here that we have to insert right after we calculate the gradients. What this utility function is doing is calculating the global norm of the gradients: every single gradient on all the parameters, you square it, you add it all up, and you take a big square root of that, and that's the norm of the gradient vector. Basically it's the length of it, if you'd like to look at it that way. We are making sure that its length is no more than 1.0, and we're going to clip it.
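The computation just described can be sketched in plain Python as below. This is an illustration of the idea, not PyTorch's implementation. On real tensors the utility to use is torch.nn.utils.clip_grad_norm_. The function name clip_global_norm, the list-of-lists representation and the small eps are assumptions made for this sketch.

```python
import math

def clip_global_norm(grads, max_norm=1.0, eps=1e-6):
    """Scale all gradients in place so their combined L2 norm is at most max_norm.

    `grads` is a list of lists of floats, one inner list per parameter tensor.
    Returns the norm measured before clipping.
    """
    # Square every gradient entry, add them all up, take one big square root.
    total_norm = math.sqrt(sum(g * g for tensor in grads for g in tensor))
    # Only ever shrink: the scale factor is capped at 1.0.
    scale = min(1.0, max_norm / (total_norm + eps))
    for tensor in grads:
        for i in range(len(tensor)):
            tensor[i] *= scale
    return total_norm

grads = [[3.0, 0.0], [4.0]]                  # two toy "parameter tensors"
norm = clip_global_norm(grads, max_norm=1.0)
print(f"norm before clipping: {norm:.4f}")   # norm before clipping: 5.0000
```

Every gradient is scaled by the same factor, so the direction of the update is unchanged and only its length is capped.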
2:18:21
The reason that people like to use this is that sometimes you can get unlucky during your optimization. Maybe it's a bad data batch or something like that, and if you get very unlucky in the batch you might get a really high loss, and a really high loss could lead to a really high gradient, and this could basically shock your model and shock the optimization. So people like to use gradient norm clipping to prevent the model from getting too big of shocks in terms of the gradient magnitude, and they upper-bound it in this way. It's a bit of a hacky solution, kind of like a patch on top of deeper issues, but people still do it fairly frequently. Now, clip_grad_norm_ returns the norm of the gradient, which I like to always visualize because it is useful information. Sometimes you can look at the norm of the gradient, and if it's well behaved, things are good. If it's climbing, things are bad and they're destabilizing during training. Sometimes you could get a spike in the norm, and that means there's some kind of an issue or an instability.
2:19:25
So the norm here will be norm, and let's do a .4f or something like that. I believe this is just a float, so we should be able to print that. So that's global gradient clipping. Now they go into the details of the learning rate scheduler. They don't just use a fixed learning rate like we do here with 3e-4; there's actually a cosine decay learning rate schedule. It's got a warmup, and it's got a cosine decay to 10% over some horizon.
2:20:06
We're going to implement this in a second. I'd just like to see the norm printed here. Okay, there we go. What happened here is the norm is actually really high in the beginning, 30 or so, and you see that as we continue training it kind of stabilizes at values below one. It's not that uncommon for the norm to be high in the very first few stages. Basically what's happening here is the model is completely random, so there's a ton of learning happening very early in the network, but that learning is mostly learning the biases of the output tokens. So it's a bit of an unstable time, but the network usually stabilizes in very few iterations. This looks relatively reasonable to me, except it looks a little bit funky that we go from 28 to 6 to 2 and then to 10. It's not completely insane, but it's just kind of a little bit funky.
2:21:05
Okay, so let's now get to the learning rate scheduler. The learning rate schedule that's used here in GPT-3 is what's called a cosine decay learning rate schedule with warmup. The way this looks is that the learning rate basically starts right at around zero, linearly ramps up over some amount of time, and then comes down with this cosine sort of form to some kind of a minimum learning rate that's up to you. Here the minimum learning rate is zero, but in the paper they said that they use cosine decay for the learning rate down to 10% of its value over the first 260 billion tokens, then training continues at 10% after that, and there's a linear warmup over the first 375 million tokens. So that's about the learning rate.
2:21:52
So let's now implement this. I already implemented it here, and the way this works is, let me scroll down first. Here I changed our training loop a little bit. This was a "for i in max steps"; I just changed it to step now, so that we have the notion of a step being a single optimization step in the for loop. Then here I get the LR for this step of the optimization using a new function I call get_lr. Then in PyTorch, to set the learning rate, I think this is the way to do it. It's a little bit gnarly, because there's a notion of different parameter groups that could exist in the optimizer, so you actually have to iterate over them, even though we currently have a single param group only, and you have to set the LR in this for-loop kind of style, is my impression right now. So we have this lookup of the LR, we set the learning rate, and then on the bottom I'm also printing it. That's all the changes I made to this loop, and then of course get_lr is my scheduler.
2:22:51
is my scheduler now it's worth pointing out that pytorch actually has learning
2:22:53
out that pytorch actually has learning
2:22:53
out that pytorch actually has learning rate schedulers and you can use them and
2:22:55
rate schedulers and you can use them and
2:22:55
rate schedulers and you can use them and I believe there's a cosine learning rate
2:22:56
I believe there's a cosine learning rate
2:22:57
I believe there's a cosine learning rate schedule in pytorch I just don't really
2:22:59
schedule in pytorch I just don't really
2:22:59
schedule in pytorch I just don't really love using that code because honestly
2:23:02
love using that code because honestly
2:23:02
love using that code because honestly it's like five lines of code and I fully
2:23:06
it's like five lines of code and I fully
2:23:06
it's like five lines of code and I fully understand what's happening inside these
2:23:07
understand what's happening inside these
2:23:07
understand what's happening inside these lines so I don't love to use
2:23:09
lines so I don't love to use
2:23:09
lines so I don't love to use abstractions where they're kind of in
2:23:11
abstractions where they're kind of in
2:23:11
abstractions where they're kind of in screwable and then I don't know what
2:23:13
screwable and then I don't know what
2:23:13
screwable and then I don't know what they're doing so personal style so the
2:23:16
So the max learning rate here is, let's say, 3e-4, but we're going to see that in GPT-3 here they have a table of what the maximum learning rate is for every model size. So for this one, basically the 12-layer, 768-dimensional model, GPT-3 Small is roughly like GPT-2 124M, and we see that here they use a learning rate of 6e-4. So we could actually go higher. In fact, we may want to try to follow that and just set the max LR here at 6e-4.
2:23:51
So that's the maximum learning rate. The minimum learning rate is 10% of that, per the description in the paper. Then there's some number of steps that we're going to warm up over, and then the maximum steps of the optimization, which I now also use in the for loop down here.
2:24:07
And then you can go over this code if you like; it's not terribly insightful or interesting. I'm just modulating, based on the iteration number, which learning rate there should be. So this is the warmup region, this is the region after the optimization, and then this is the region sort of in between, and this is where I calculate the cosine learning rate schedule. You can step through this in detail if you'd like, but this is basically implementing this curve. And I ran this already, and this is what that looks like.
2:24:40
So when we now run, we start at some very low number. Note that we don't start exactly at zero, because it would not be useful to update with a learning rate of zero. That's why there's an it+1, so that on the zeroth iteration we are not using exactly zero; we're using something very, very low. Then we linearly warm up to the maximum learning rate, which in this case was 3e-4 when I ran it but now would be 6e-4, and then it starts to decay all the way down to 3e-5, which was at the time 10% of the original learning rate.
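The schedule described here (linear warmup with an it+1 offset, cosine decay in between, and the minimum learning rate once the optimization is over) can be sketched in a few lines. This is a minimal sketch and not necessarily the exact code on screen: the values of warmup_steps and max_steps are placeholder assumptions, while max_lr follows the 6e-4 figure quoted for GPT-3 Small.

```python
import math

max_lr = 6e-4          # GPT-3 Small value quoted in the video
min_lr = max_lr * 0.1  # 10% of the maximum
warmup_steps = 10      # assumption: placeholder value for illustration
max_steps = 50         # assumption: placeholder value for illustration

def get_lr(it):
    # 1) linear warmup; (it + 1) so that step 0 does not use a learning rate of 0
    if it < warmup_steps:
        return max_lr * (it + 1) / warmup_steps
    # 2) past the end of the optimization: hold the minimum learning rate
    if it > max_steps:
        return min_lr
    # 3) in between: cosine decay from max_lr down to min_lr
    decay_ratio = (it - warmup_steps) / (max_steps - warmup_steps)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # goes from 1 to 0
    return min_lr + coeff * (max_lr - min_lr)

for step in (0, 9, 10, 30, 50, 51):
    print(f"step {step:2d} | lr {get_lr(step):.4e}")
```

In a PyTorch training loop, the returned value would typically be written into the 'lr' entry of every group in optimizer.param_groups before calling optimizer.step().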
2:25:16
Now, one thing we are not following exactly (let me see if I can find it again): they mention that their training horizon is 300 billion tokens, and they come down to 10% of the initial learning rate at 260 billion, and then they train after 260 billion with that 10%. So basically their decay time is less than the max steps time, whereas for us they're exactly equal. So it's not exactly faithful, but this is okay for us and for our purposes right now, and we're just going to use this ourselves. I don't think it makes too big of a difference, honestly.
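For comparison, here is a hedged sketch of the variant the paper describes, where the cosine decay finishes before training does and the run then continues at the minimum learning rate. The step counts are assumptions: 300 and 260 simply stand in for the 300 billion and 260 billion token marks, and the function name is made up for this example.

```python
import math

max_lr = 6e-4
min_lr = 0.1 * max_lr
warmup_steps = 10   # assumption: illustrative only
max_steps = 300     # assumption: stands in for the 300B-token training horizon
decay_steps = 260   # assumption: stands in for the 260B-token end of the decay

def get_lr_with_floor(it):
    # linear warmup, same it + 1 trick as before
    if it < warmup_steps:
        return max_lr * (it + 1) / warmup_steps
    # the decay is finished before training ends; hold min_lr from here on
    if it >= decay_steps:
        return min_lr
    # cosine decay stretched over [warmup_steps, decay_steps) instead of max_steps
    decay_ratio = (it - warmup_steps) / (decay_steps - warmup_steps)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
    return min_lr + coeff * (max_lr - min_lr)

for step in (0, 10, 135, 260, max_steps - 1):
    print(f"step {step:3d} | lr {get_lr_with_floor(step):.4e}")
```

Setting decay_steps equal to max_steps recovers the simpler schedule used in the video.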
2:26:02
I should point out that what learning rate schedule you use is totally up to you. There are many different types. The cosine learning rate schedule has been popularized a lot by GPT-2 and GPT-3, but people have come up with all kinds of other learning rate schedules, and this is kind of an active area of research as to which one is the most effective at training these networks.
2:26:23
Okay, next up, the paper talks about the gradual batch size increase. So there's a ramp on the batch size that is linear: you start with a very small batch size and you ramp up to a big batch size over time. We're going to actually skip this and we're not going to work with it. The reason I don't love to use it is that it complicates a lot of the arithmetic, because you are changing the number of tokens that you're processing at every single step of the optimization, and I like to keep that math very, very simple.
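To make the bookkeeping point concrete, here is a small sketch of a linear batch-size ramp. With a fixed batch size, the tokens seen after n steps is just n times the batch size; with a ramp, you have to sum a per-step value. All numbers below are hypothetical and chosen only for illustration, as are the function names.

```python
def batch_tokens_at(step, start_tokens, full_tokens, ramp_steps):
    """Tokens processed at `step` under a linear ramp from start_tokens to full_tokens."""
    if step >= ramp_steps:
        return full_tokens
    # integer arithmetic keeps the result exact
    return start_tokens + (full_tokens - start_tokens) * step // ramp_steps

def tokens_seen(num_steps, start_tokens, full_tokens, ramp_steps):
    """Total tokens processed in the first num_steps steps."""
    return sum(
        batch_tokens_at(s, start_tokens, full_tokens, ramp_steps)
        for s in range(num_steps)
    )

# hypothetical values: ramp from 32k tokens to about 0.5M tokens over 100 steps
start, full, ramp = 32_768, 524_288, 100

fixed_total = 200 * full                         # fixed batch size: one multiplication
ramped_total = tokens_seen(200, start, full, ramp)  # ramp: needs a running sum

print("fixed batch size :", fixed_total)
print("ramped batch size:", ramped_total)
```

Anything that is defined in terms of tokens (warmup length, decay horizon, evaluation intervals) then has to be converted through this running sum instead of a single multiplication.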
2:26:49
Also, my understanding is that this is not a major improvement, and also my understanding is that this is not an algorithmic optimization improvement; it's more of a systems and speed improvement. Roughly speaking, this is because in the early stages of the optimization, again, the model is in a very atypical setting, and mostly what you're learning is to ignore the tokens that don't come up in your training set very often. You're learning very simple biases and that kind of thing.
2:27:23
And so every single example that you put through your network is basically just telling you: use these tokens and don't use these tokens. So the gradients from every single example are actually extremely highly correlated. They all look roughly the same in the early parts of the optimization, because they're all just telling you that these tokens don't appear and these tokens do appear. And so, because the gradients are all very similar and highly correlated, why would you be doing batch sizes of millions when, if you do a batch size of 32k, you're basically getting the exact same gradient early on in training?
2:27:55
Then later in the optimization, once you've learned all the simple stuff, that's where the actual work starts, and that's where the gradients become more decorrelated across examples, and that's where they actually offer you statistical power in some sense.
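A toy simulation can illustrate the argument; it is not a measurement from a real training run. The assumption is that each per-example gradient is a shared direction plus independent Gaussian noise. A small noise level stands in for the early, highly correlated regime and a large one for the later, decorrelated regime. The batch sizes and noise levels are arbitrary choices.

```python
import math
import random

def mean_gradient(batch_size, shared, noise_std, rng):
    """Average of batch_size toy per-example gradients: shared direction + Gaussian noise."""
    dim = len(shared)
    total = [0.0] * dim
    for _ in range(batch_size):
        for j in range(dim):
            total[j] += shared[j] + rng.gauss(0.0, noise_std)
    return [t / batch_size for t in total]

def cosine(a, b):
    """Cosine similarity between two vectors given as lists."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

rng = random.Random(0)
dim = 16
shared = [1.0 / math.sqrt(dim)] * dim  # unit-norm direction common to all examples

# early regime (assumed): examples nearly agree, so a small batch matches a large one
early_cos = cosine(mean_gradient(32, shared, 0.1, rng),
                   mean_gradient(1024, shared, 0.1, rng))

# late regime (assumed): examples disagree, so the small batch is a noisy estimate
late_cos = cosine(mean_gradient(32, shared, 5.0, rng),
                  mean_gradient(1024, shared, 5.0, rng))

print(f"early: cosine(small batch, large batch) = {early_cos:.4f}")
print(f"late : cosine(small batch, large batch) = {late_cos:.4f}")
```

In the correlated regime the small-batch average points in almost the same direction as the large-batch average, which is the sense in which a bigger batch buys very little early on.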
2:28:10
So we're going to skip this, just because it kind of complicates things, and we're going to go to the next point: data are sampled without replacement during training, until an epoch boundary is reached. Without replacement means that they're not sampling from some fixed pool where you take a sequence, train on it, and then also return the sequence to the pool. They are exhausting a pool, so when they draw a sequence, it's gone until the next epoch of training.
2:28:39
We're already doing that, because our data loader iterates over chunks of data, so there's no replacement: they don't become eligible to be drawn again until the next epoch. So we're basically already doing that.
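As a rough illustration of why walking over consecutive chunks amounts to sampling without replacement, here is a minimal pure-Python sketch of such a loader. The class name and the list-based batches are assumptions made for this example; the loader in the video works on tensors.

```python
class SequentialChunkLoader:
    """Walks a token list in consecutive, non-overlapping chunks of B*T tokens."""

    def __init__(self, tokens, B, T):
        self.tokens = tokens
        self.B = B
        self.T = T
        self.current_position = 0
        self.epoch = 0

    def next_batch(self):
        B, T = self.B, self.T
        # one extra token so that the targets can be shifted by one position
        buf = self.tokens[self.current_position : self.current_position + B * T + 1]
        x = [buf[i * T : (i + 1) * T] for i in range(B)]          # inputs
        y = [buf[i * T + 1 : (i + 1) * T + 1] for i in range(B)]  # next-token targets
        self.current_position += B * T
        # if the next batch would run past the end, start a new epoch from the top
        if self.current_position + (B * T + 1) > len(self.tokens):
            self.current_position = 0
            self.epoch += 1
        return x, y

loader = SequentialChunkLoader(list(range(25)), B=2, T=3)
seen = []
for _ in range(4):  # 4 batches of 6 input tokens fit into the 25-token stream
    x, _ = loader.next_batch()
    seen.extend(tok for row in x for tok in row)
print(seen)          # every input token appears once within the epoch
print(loader.epoch)  # the position has wrapped, so a new epoch has started
```

No chunk can be drawn twice within an epoch because the read position only moves forward until it wraps.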
2:28:56
All models use a weight decay of 0.1 to provide a small amount of regularization, so let's implement weight decay. You see here that I've already kind of made the changes. In particular, instead of creating the optimizer right here, I'm creating a new configure_optimizers function inside the model and I'm passing in some of the hyperparameters instead. So let's look at configure_optimizers, which is supposed to return the optimizer object.
2:29:24
Okay, so it looks complicated, but it's actually really simple; we're just being very careful, and there are a few settings here to go through. The most important thing with respect to this line is that you see there's a weight_decay parameter here, and I'm passing that into something called optim_groups, which eventually ends up going into the AdamW optimizer. The weight decay that's used by default in AdamW here is 0.01, so it's 10 times lower than what's used in the GPT-3 paper. So the weight decay basically ends up making its way into AdamW through the optimizer groups.
2:30:05
Now what else is going on in this function? One of the two important things happening here is that I'm splitting up the parameters into those that should be weight decayed and those that should not be weight decayed. In particular, it is common not to weight decay biases and any other sort of one-dimensional tensors. So the one-dimensional tensors are in the no-decay params, and these are also things like layer norm scales and biases. It doesn't really make sense to weight decay those. You mostly want to weight decay the weights that participate in matrix multiplications, and you potentially want to weight decay the embeddings.
2:30:43
We've covered in a previous video why it makes sense to decay the weights: you can sort of think of it as a regularization, because when you're pulling down all the weights, you're forcing the optimization to use more of the weights, and you're not allowing any one of the weights individually to be way too large. You're forcing the network to distribute the work across more channels, because there's sort of a pull of gravity on the weights themselves.
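The "pull of gravity" can be seen directly in the AdamW update rule, where the decay is applied to the weight itself, separately from the gradient step. Below is a minimal scalar sketch of one AdamW step. The learning rate and weight decay follow the values mentioned in the video; the beta and eps values are assumptions, and this is an illustration of the rule, not PyTorch's implementation.

```python
import math

def adamw_step(p, g, m, v, t, lr=6e-4, beta1=0.9, beta2=0.95,
               eps=1e-8, weight_decay=0.1):
    """One AdamW update for a scalar parameter p with gradient g (t starts at 1)."""
    # decoupled weight decay: shrink the weight toward zero, independent of g
    p = p * (1.0 - lr * weight_decay)
    # exponential moving averages of the gradient and the squared gradient
    m = beta1 * m + (1.0 - beta1) * g
    v = beta2 * v + (1.0 - beta2) * g * g
    # bias correction for the zero-initialized averages
    m_hat = m / (1.0 - beta1 ** t)
    v_hat = v / (1.0 - beta2 ** t)
    # the usual Adam step
    p = p - lr * m_hat / (math.sqrt(v_hat) + eps)
    return p, m, v

# With a zero gradient, the only force acting on the weight is the decay.
p, m, v = 1.0, 0.0, 0.0
for t in range(1, 1001):
    p, m, v = adamw_step(p, 0.0, m, v, t)
print(f"weight after 1000 zero-gradient steps: {p:.6f}")
```

A weight only stays large if the gradient keeps pushing it there, which is why the decay acts as a mild regularizer.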
2:31:13
So that's why we are separating it in those ways here: we're only decaying the embeddings and the matmul-participating weights. We're printing the number of parameters that we're decaying and not decaying. Most of the parameters will be decayed.
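The split itself can be sketched without any framework by looking only at parameter shapes: anything with two or more dimensions is decayed, everything else is not. The shape table below follows the standard GPT-2 124M layout, with the output head tied to the token embedding and therefore not listed separately. The function names are made up for this sketch, and the groups hold parameter names where a real optimizer would receive tensors.

```python
def gpt2_param_shapes(n_layer=12, n_embd=768, vocab_size=50257, block_size=1024):
    """Name -> shape for a GPT-2-style model (lm_head is tied to wte, so not listed)."""
    shapes = {
        "wte.weight": (vocab_size, n_embd),
        "wpe.weight": (block_size, n_embd),
        "ln_f.weight": (n_embd,),
        "ln_f.bias": (n_embd,),
    }
    for i in range(n_layer):
        p = f"h.{i}."
        shapes[p + "ln_1.weight"] = (n_embd,)
        shapes[p + "ln_1.bias"] = (n_embd,)
        shapes[p + "attn.c_attn.weight"] = (3 * n_embd, n_embd)
        shapes[p + "attn.c_attn.bias"] = (3 * n_embd,)
        shapes[p + "attn.c_proj.weight"] = (n_embd, n_embd)
        shapes[p + "attn.c_proj.bias"] = (n_embd,)
        shapes[p + "ln_2.weight"] = (n_embd,)
        shapes[p + "ln_2.bias"] = (n_embd,)
        shapes[p + "mlp.c_fc.weight"] = (4 * n_embd, n_embd)
        shapes[p + "mlp.c_fc.bias"] = (4 * n_embd,)
        shapes[p + "mlp.c_proj.weight"] = (n_embd, 4 * n_embd)
        shapes[p + "mlp.c_proj.bias"] = (n_embd,)
    return shapes

def numel(shape):
    """Number of elements in a tensor of the given shape."""
    n = 1
    for d in shape:
        n *= d
    return n

def split_for_weight_decay(shapes, weight_decay):
    """2D+ tensors (matmul weights, embeddings) decay; 1D tensors (biases, norms) do not."""
    decay = {name: s for name, s in shapes.items() if len(s) >= 2}
    no_decay = {name: s for name, s in shapes.items() if len(s) < 2}
    groups = [
        {"params": list(decay), "weight_decay": weight_decay},
        {"params": list(no_decay), "weight_decay": 0.0},
    ]
    return decay, no_decay, groups

decay, no_decay, optim_groups = split_for_weight_decay(gpt2_param_shapes(), 0.1)
num_decay = sum(numel(s) for s in decay.values())
num_no_decay = sum(numel(s) for s in no_decay.values())
print(f"decayed tensors: {len(decay)}, with {num_decay:,} parameters")
print(f"non-decayed tensors: {len(no_decay)}, with {num_no_decay:,} parameters")
```

In PyTorch, a list of dicts shaped like optim_groups, but holding the parameter tensors themselves, is what gets passed to torch.optim.AdamW so that each group carries its own weight_decay.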
2:31:27
Then one more thing that we're doing here: I'm doing another optimization. Previous versions of AdamW did not have this option, but later versions of PyTorch introduced it, and that's why I'm guarding it with an inspect.signature, which is basically checking whether this fused kwarg is present inside AdamW. If it is present, I'm going to end up using it and passing it in here, because some earlier versions do not have fused=.
2:31:55
So here's AdamW's fused=. It did not used to exist and it was added later, and there are some docs here for what's happening. Basically, they say that by default they do not use fused because it is relatively new and they want to give it sufficient bake-in time. So by default they don't use fused, but fused is a lot faster when it is available and when you're running on CUDA.
2:32:17
What that does is this: instead of iterating in a for loop over all the parameter tensors and updating them, which would launch a lot of kernels, fused just means that all those kernels are fused into a single kernel. You get rid of a lot of overhead, and you call a kernel a single time on all the parameters to update them. So it's basically just kernel fusion for the AdamW update, instead of iterating over all the tensors.
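The version guard can be sketched with the standard library alone. The two optimizer functions below are hypothetical stand-ins for an older and a newer constructor, not PyTorch code, and the helper names are made up; the same check against torch.optim.AdamW would tell you whether your installed version accepts fused.

```python
import inspect

def supports_kwarg(fn, name):
    """True if the callable fn declares a parameter called name."""
    return name in inspect.signature(fn).parameters

# Hypothetical stand-ins for an older and a newer optimizer constructor.
def old_optimizer(params, lr=1e-3, weight_decay=0.01):
    return {"params": params, "lr": lr, "weight_decay": weight_decay}

def new_optimizer(params, lr=1e-3, weight_decay=0.01, fused=None):
    return {"params": params, "lr": lr, "weight_decay": weight_decay, "fused": fused}

def build(optimizer_fn, params, on_cuda):
    # only pass fused=True if the constructor knows the kwarg and we are on CUDA
    use_fused = supports_kwarg(optimizer_fn, "fused") and on_cuda
    extra_args = {"fused": True} if use_fused else {}
    return optimizer_fn(params, lr=6e-4, weight_decay=0.1, **extra_args)

print(build(old_optimizer, [], on_cuda=True))   # kwarg unknown: not passed
print(build(new_optimizer, [], on_cuda=True))   # kwarg known and on CUDA: fused=True
print(build(new_optimizer, [], on_cuda=False))  # kwarg known but not on CUDA: default
```

Guarding on the signature, instead of on a version string, keeps the same script working across library versions without hard-coding when the option was introduced.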
2:32:48
So that's the configure_optimizers function that I like to use. We can rerun, and we're not going to see any major differences from what we saw before, but we are going to see some prints coming from here, so let's take a look at what they look like.
2:33:01
We see that the number of decayed tensors is 50, and that's most of the parameters, and the number of non-decayed tensors is 98, and these are mostly the biases and the layer norm parameters; there are only about 100,000 of those. So most of it is decayed. And then we are using the fused implementation of AdamW, which will be a lot faster, so if you have it available I would advise you to use it. I'm not actually 100% sure why
2:33:28
they don't default to it it seems fairly
2:33:28
they don't default to it it seems fairly benign and
2:33:29
benign and
2:33:29
benign and harmless and also because we are using
2:33:31
Also, because we are using the fused implementation, I think this is why the running time has dropped. Notice that it used to be 93 milliseconds per step, and we're now down to 90 milliseconds per step because of using the fused AdamW optimizer. So in a single commit here we are introducing fused AdamW, getting improvements on the time, and we're adding, or changing, the weight decay, but we're only weight decaying the two-dimensional parameters: the embeddings and the matrices that participate in linear layers. So that is this, and we can take this out, and that is it for this line.
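The split between decayed and non-decayed tensors can be sketched as follows. This is an illustration under assumptions, not the exact configure_optimizers from the video: the small Sequential model is a hypothetical stand-in for the Transformer, and the weight_decay and lr values are placeholders.

```python
import torch

# Hypothetical stand-in model with an embedding, a LayerNorm and a Linear layer.
model = torch.nn.Sequential(
    torch.nn.Embedding(100, 16),
    torch.nn.LayerNorm(16),
    torch.nn.Linear(16, 100),
)

params = [p for p in model.parameters() if p.requires_grad]
# Two-dimensional tensors (embeddings, linear weight matrices) get weight decay.
decay_params = [p for p in params if p.dim() >= 2]
# One-dimensional tensors (biases, LayerNorm scale and shift) do not.
nodecay_params = [p for p in params if p.dim() < 2]

weight_decay = 0.1  # placeholder value
optim_groups = [
    {"params": decay_params, "weight_decay": weight_decay},
    {"params": nodecay_params, "weight_decay": 0.0},
]
optimizer = torch.optim.AdamW(optim_groups, lr=3e-4)

num_decay = sum(p.numel() for p in decay_params)
num_nodecay = sum(p.numel() for p in nodecay_params)
print(f"decayed tensors: {len(decay_params)}, with {num_decay} parameters")
print(f"non-decayed tensors: {len(nodecay_params)}, with {num_nodecay} parameters")
```

Grouping by tensor dimensionality keeps the rule simple: no parameter names need to be matched, and new layers fall into the right group automatically.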
2:34:10
One more quick note before we continue. I just want to point out that the relationships between weight decay, learning rate, batch size, the Adam parameters beta1 and beta2, the epsilon, and so on are very complicated mathematical relationships in the optimization literature. For the most part, in this video I'm just trying to copy-paste the settings that OpenAI used. But this is a complicated topic, quite deep, and in this video I just want to copy the parameters, because it's a whole different video to really talk about that in detail and do it proper justice instead of just giving high-level intuitions.
2:34:45
Now, the next thing that I want to move on to. This paragraph here, by the way, we're going to come back to when we improve our data loader. For now I want to swing back around to this table, where you will notice that for different models we of course have different hyperparameters for the Transformer that dictate the size of the Transformer network. We also have a different learning rate, so we're seeing the pattern that the bigger networks are trained with slightly lower learning rates. And we also see this batch size, where in the small networks they use a smaller batch size and in the bigger networks they use a bigger batch size.
2:35:28
Now the problem for us is we can't just use a 0.5 million batch size, because if I just come in here and try to set this B (where is my B... where do I call the data loader... okay, B = 16), if I try to set it... well, we have to be careful: it's not 0.5 million, because this is the batch size in the number of tokens. Every single one of our rows is 1024 tokens, so 0.5e6 divided by 1024 would need a batch size of about 488. So the problem is I can't come in here and set this to 488, because my GPU would explode. This would not fit for sure.
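The arithmetic above can be reproduced in a couple of lines. The variable names here are chosen for illustration only.

```python
# Batch size from the table, measured in tokens (0.5 million).
tokens_per_batch = 0.5e6
# Each row of the batch is one sequence of 1024 tokens.
T = 1024

rows_needed = tokens_per_batch / T
print(f"rows needed per batch: {rows_needed:.1f}")
```

The result is a bit over 488 rows, far more than the B = 16 that is being used on this GPU.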
2:36:18
But we still want to use this batch size, because again, as I mentioned, the batch size is correlated with all the other optimization hyperparameters, the learning rates, and so on. We want to have a faithful representation of all the hyperparameters, and therefore we need to use a batch size of roughly 0.5 million. But the question is: how do we use 0.5 million if we only have a small GPU? For that we need to use what's called gradient accumulation, so we're going to turn to that next. It allows us to simulate, in a serial way, any arbitrary batch size that we set. So we can do a batch size of 0.5 million; we just have to run longer, and we have to process multiple sequences and basically add up all the gradients from them to simulate a batch size of 0.5 million. So let's turn to that next.
2:37:05
Okay, so I started the implementation right here just by adding these lines of code. Basically what I did is first I set the total batch size that we desire. So this is about 0.5 million, and I used a nice number, a power of two: 2 to the 19 is 524,288, so it's roughly 0.5 million. Now our micro batch size, as we call it now, is 16. So we still have B by T indices that go into the Transformer and do forward backward, but we're not going to do an update. We're going to do many forward backwards, and those gradients are all going to += on the parameter gradients; they're all going to add up. So we're going to do forward backward grad_accum_steps number of times, and then we're going to do a single update once all that is accumulated.
2:37:55
So in particular, our micro batch size is now just controlling how many tokens, how many rows, we're processing in a single go of a forward backward. Here we are doing 16 × 1024, so 16,384 tokens per forward backward, and we are supposed to be doing 2 to the 19 in total, so grad_accum_steps here will work out to 32. We have to do 32 forward backwards and then a single update. Now, we see that we have about 100 milliseconds for a single forward backward, so doing 32 of them will make every step roughly 3 seconds, just napkin math.
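The same napkin math as a short sketch, using the numbers quoted above. The variable names are assumptions made for illustration, and the 100 ms figure is the rough timing mentioned in the transcript.

```python
total_batch_size = 2**19  # 524,288 tokens, roughly 0.5 million
B = 16                    # micro batch size (rows per forward backward)
T = 1024                  # tokens per row

tokens_per_micro_step = B * T
# The total has to split evenly into micro batches.
assert total_batch_size % tokens_per_micro_step == 0
grad_accum_steps = total_batch_size // tokens_per_micro_step

ms_per_micro_step = 100   # rough time for one forward backward
step_seconds = grad_accum_steps * ms_per_micro_step / 1000

print(f"tokens per forward backward: {tokens_per_micro_step}")
print(f"grad_accum_steps: {grad_accum_steps}")
print(f"estimated time per optimizer step: {step_seconds:.1f} s")
```

Choosing a power of two for the total makes the divisibility check pass for any power-of-two B and T.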
2:38:48
math so that's grum steps but now we
2:38:48
math so that's grum steps but now we actually have to Implement that so we're
2:38:50
actually have to Implement that so we're
2:38:50
actually have to Implement that so we're going to swing over to our training Loop
2:38:54
going to swing over to our training Loop
2:38:54
going to swing over to our training Loop because now this part
2:38:56
because now this part
2:38:56
because now this part here and this part here the forward and
2:38:59
here and this part here the forward and
2:38:59
here and this part here the forward and the backward we have to now repeat this
2:39:01
the backward we have to now repeat this
2:39:01
the backward we have to now repeat this 32 times before we do everything else
2:39:04
32 times before we do everything else
2:39:04
32 times before we do everything else that follows so let's uh see how we can
2:39:06
that follows so let's uh see how we can
2:39:06
that follows so let's uh see how we can Implement that so let's come over here
2:39:09
Implement that so let's come over here
2:39:09
Implement that so let's come over here and actually we do have to load a new
2:39:10
and actually we do have to load a new
2:39:10
and actually we do have to load a new batch every single time so let me move
2:39:12
batch every single time so let me move
2:39:12
batch every single time so let me move that over here and now this is where we
2:39:14
that over here and now this is where we
2:39:14
that over here and now this is where we have the inner loop so for micro step in
2:39:18
have the inner loop so for micro step in
2:39:18
have the inner loop so for micro step in range graum
2:39:20
range graum
2:39:20
range graum steps we do this and remember that l.
2:39:24
steps we do this and remember that l.
2:39:24
steps we do this and remember that l. backward always deposits gradients so
2:39:26
backward always deposits gradients so
2:39:26
backward always deposits gradients so we're doing inside losta backward
2:39:27
we're doing inside losta backward
2:39:27
we're doing inside losta backward there's always a plus equals on the
2:39:29
there's always a plus equals on the
2:39:29
there's always a plus equals on the gradients so in every single L of
2:39:31
gradients so in every single L of
2:39:31
gradients so in every single L of backward gradients will add up on the
2:39:33
backward gradients will add up on the
2:39:33
backward gradients will add up on the gradient
2:39:35
gradient
2:39:35
gradient tensors um so we lost that backward and
2:39:38
tensors um so we lost that backward and
2:39:38
tensors um so we lost that backward and then we get all the gradients over there
2:39:41
then we get all the gradients over there
2:39:41
then we get all the gradients over there and then we normalize and everything
2:39:43
and then we normalize and everything
2:39:43
and then we normalize and everything else should just follow um so we're very
2:39:47
else should just follow um so we're very
2:39:47
else should just follow um so we're very close but actually there's like subtle
2:39:50
close but actually there's like subtle
2:39:50
close but actually there's like subtle and deep issue here and this is actually
2:39:52
and deep issue here and this is actually
2:39:52
and deep issue here and this is actually incorrect so invite I invite you to
2:39:54
incorrect so invite I invite you to
2:39:54
incorrect so invite I invite you to think about why this is not yet
2:39:56
think about why this is not yet
2:39:56
think about why this is not yet sufficient um and uh let me fix it then
2:39:59
sufficient um and uh let me fix it then
2:39:59
sufficient um and uh let me fix it then okay so I brought back the jupyter
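The accumulation behavior mentioned above, that loss.backward() adds into .grad rather than overwriting it, can be checked in isolation. This is a tiny sketch with made-up tensors, separate from the training loop.

```python
import torch

w = torch.tensor([1.0, 2.0], requires_grad=True)
x = torch.tensor([3.0, 4.0])

# First backward: d/dw of sum(w * x) is x.
(w * x).sum().backward()
first = w.grad.clone()
print(first)   # → tensor([3., 4.])

# Second backward without zeroing: the new gradient is added to the old one.
(w * x).sum().backward()
second = w.grad.clone()
print(second)  # → tensor([6., 8.])
```

This is why the gradients have to be zeroed once per optimizer step, and not once per micro step, when accumulating.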
2:40:00
Okay, so I brought back the Jupyter notebook so we can think about this carefully in a simple toy setting and see what's happening. Let's create a very simple neural net that takes a vector of 16 numbers and returns a single number. Then here I'm creating some random examples x and some targets y, and we are using the mean squared loss here to calculate the loss. So basically this is four individual examples, and we're just doing simple regression with the mean squared loss over those four examples. Now when we calculate the loss, call loss.backward(), and look at the gradient, this is the gradient that we get.
2:40:44
Now, the loss objective here: notice that in MSE loss the default for the loss function is reduction='mean', so we're calculating the mean loss here over the four examples. So this is the exact loss objective, and this is the average, the one over four, because there are four independent examples here. Then we have the four examples and their squared error, and then this makes it the mean squared error. So we calculate the squared error and then we normalize it to make it the mean over the examples, and there are four examples here.
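Written out, the objective being described is the following. The notation is an assumption made here, not taken from the notebook: $\hat{y}_i$ is the net's output on example $x_i$ and $y_i$ is its target.

```latex
L = \frac{1}{4} \sum_{i=1}^{4} \left( \hat{y}_i - y_i \right)^2
```

The factor $\frac{1}{4}$ is what reduction='mean' contributes; the sum of squared errors alone would correspond to reduction='sum'.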
2:41:22
So now when we come to the gradient accumulation version of it, this here is the gradient accumulation version, where we have grad_accum_steps of four. I reset the gradient, and now I'm evaluating all the examples individually instead, and calling loss.backward() on them many times, and then we're looking at the gradient that we get from that. So basically now we forward our function, calculate the exact same loss, do a backward, and we do that four times. And when we look at the gradient, you'll notice that the gradients don't match.
2:42:00
So here we did a single batch of four, and here we did four gradient accumulation steps of batch size one, and the gradients are not the same. Basically, the reason they're not the same is exactly because the mean in this mean squared error gets lost: this one quarter in this loss gets lost. What happens here is that the loss objective for every one of the loops is just a mean squared error, which in this case, because there's only a single example, is just this term here. So that was the loss in the zeroth iteration, same in the first, second, third, and so on. Then when you do the loss.backward() we're accumulating gradients, and accumulation in the gradient is basically equivalent to doing a sum in the loss. So our loss here is actually this, without the factor of one quarter outside of it. So we're missing the normalizer, and therefore our gradients are off.
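In symbols, using notation assumed here for illustration: let $\ell_i = (\hat{y}_i - y_i)^2$ be the loss of the single example seen in micro step $i$, and let $L = \frac{1}{4}\sum_{i=1}^{4} \ell_i$ be the mean loss over the batch of four. Summing the four per-step gradients gives

```latex
\sum_{i=1}^{4} \nabla_\theta \ell_i
  = \nabla_\theta \sum_{i=1}^{4} \ell_i
  = 4 \, \nabla_\theta L
```

so the accumulated gradient is four times the gradient of the batch objective, which is exactly the missing one quarter.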
2:42:58
And so the way to fix this, or one of them, is basically we can come here and say loss = loss / 4. What happens now is that we're scaling our loss: we're introducing a one quarter in front of all of these places. So all the individual losses are now scaled by one quarter, and then when we backward, all of these accumulate with a sum, but now there's a one quarter inside every one of these components, and now our losses will be equivalent. So when I run this, you see that the gradients are now identical.
2:43:37
So, long story short, when you step through this simple example you can see the reason that this is not correct. In the same way as here in the MSE loss, the loss that we're calculating here in the model is using a reduction of mean as well. So where's the loss? F.cross_entropy. And by default the reduction here in cross entropy is also the mean (I don't know why they don't show it), the mean loss over all the B by T elements. So there's a reduction by mean in there, and if we're just doing this gradient accumulation here, we're missing that. The way to fix this is to simply compensate for the number of gradient accumulation steps, and we can divide this loss in the same way. In particular, here we say loss = loss / grad_accum_steps. Even Copilot gets the modification. In exactly the same way, we are scaling down the loss so that when we do loss.backward(), which basically corresponds to a sum in the objective, we are summing up the already normalized loss. Therefore, when we sum up the losses divided by grad_accum_steps, we are recovering the additional normalizer, and so this will now be equivalent to the original optimization, because the gradient will come out the same.
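To make the normalizer argument concrete, here is a minimal pure-Python sketch. It is not the code from the video, which uses PyTorch tensors and autograd. It assumes a one-parameter model y_hat = w * x with a mean-squared-error loss over four examples; the data values and the helper name grad_single are made up for illustration.

```python
# Hypothetical toy data: 4 examples for a one-parameter model y_hat = w * x.
xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 1.0, 0.5, 3.0]
w = 0.3

def grad_single(w, x, y):
    """Derivative with respect to w of the squared error (w*x - y)**2."""
    return 2.0 * (w * x - y) * x

# Full batch with mean reduction: L = (1/4) * sum_i (w*x_i - y_i)**2
grad_full = sum(grad_single(w, x, y) for x, y in zip(xs, ys)) / len(xs)

# Naive accumulation: one example per micro step, gradients simply add up.
# This is the gradient of the SUM of the losses, so the 1/4 is missing.
grad_naive = 0.0
for x, y in zip(xs, ys):
    grad_naive += grad_single(w, x, y)

# Fixed accumulation: scale every micro-step loss by 1 / grad_accum_steps.
grad_accum_steps = len(xs)
grad_fixed = 0.0
for x, y in zip(xs, ys):
    grad_fixed += grad_single(w, x, y) / grad_accum_steps

print(grad_full, grad_naive, grad_fixed)
```

The naive gradient comes out four times too large, while the scaled version matches the full-batch gradient up to floating point error.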
2:45:05
Okay, so I had to do a few more touch-ups, and I launched the optimization here. In particular, because we want to print things nicely, first of all we need to create an accumulator over the loss. We can't just print the loss, because we'd be printing only the final loss at the final micro step. So instead we have loss_accum, which I initialize at zero, and then I accumulate the loss into it. I'm using detach() so that I'm detaching the tensor from the graph, and I'm just trying to keep track of the values, so I'm making these leaf nodes when I add them. So that's loss_accum, and then we're printing that here instead of loss. In addition to that, I had to account for grad_accum_steps inside the tokens processed, because now the tokens processed per step is B * T * grad_accum_steps.
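The bookkeeping described here can be sketched with plain floats standing in for tensors. In the PyTorch script the accumulation would use loss.detach() as described above. The function name summarize_step, the loss values, and the 3-second step time are assumptions for illustration only.

```python
def summarize_step(micro_losses, B, T, dt):
    """Return (loss_accum, tokens_processed, tokens_per_sec) for one optimizer step."""
    grad_accum_steps = len(micro_losses)
    loss_accum = 0.0
    for loss in micro_losses:
        # Each micro-step loss is already a mean over its B*T tokens;
        # scaling by 1/grad_accum_steps makes loss_accum the mean over the full batch.
        loss_accum += loss / grad_accum_steps
    tokens_processed = B * T * grad_accum_steps
    tokens_per_sec = tokens_processed / dt
    return loss_accum, tokens_processed, tokens_per_sec

# Hypothetical numbers: 32 micro steps of B=16, T=1024 that take 3 seconds in total.
losses = [10.0 + 0.01 * i for i in range(32)]
loss_accum, tokens, tok_per_sec = summarize_step(losses, B=16, T=1024, dt=3.0)
print(f"loss {loss_accum:.4f} | tokens {tokens} | tok/sec {tok_per_sec:.0f}")
```

Printing only the last entry of losses would report a single micro step, which is the problem the accumulator avoids.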
2:45:57
So, long story short, here we have the optimization. It looks reasonable, right? We're starting at a good spot, we calculated grad_accum_steps to be 32, and we're getting about 3 seconds here. So this looks pretty good. Now, if you'd like to verify that your optimization and the implementation here is correct, and you're working on the side: because we have the total batch size and the gradient accumulation steps, our setting of B is now purely a performance optimization kind of setting. If you have a big GPU, you can increase this to 32 and you'll probably go a bit faster; if you have a very small GPU, you can try 8 or 4. But in any case you should be getting the exact same optimization and the same answers, up to floating point error, because gradient accumulation kicks in and can handle everything serially as necessary. So that's it for gradient accumulation, I think.
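As a quick check of why B only affects speed, the sketch below derives the number of accumulation steps for a few micro-batch sizes. It assumes the total batch size of 524288 tokens and a sequence length T of 1024, which are consistent with the 32 steps at B = 16 mentioned above; the helper name grad_accum_steps_for is hypothetical.

```python
total_batch_size = 524288  # tokens per optimizer step (2**19)
T = 1024                   # sequence length

def grad_accum_steps_for(B):
    """Number of micro steps needed so that B * T * steps == total_batch_size."""
    assert total_batch_size % (B * T) == 0, "total_batch_size must be divisible by B * T"
    return total_batch_size // (B * T)

table = {B: grad_accum_steps_for(B) for B in (4, 8, 16, 32)}
print(table)  # {4: 128, 8: 64, 16: 32, 32: 16}
```

Every row processes the same 524288 tokens per optimizer step; a smaller B just takes more serial micro steps to get there.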
2:46:53
Okay, so now is the time to bring out the heavy weapons. You've noticed that so far we've only been using a single GPU for training, but actually I am paying for eight GPUs here, so we should be putting all of them to work. In particular, they are going to collaborate and optimize over tokens at the same time, and communicate so that they're all collaborating on the optimization. For this we are going to be using DistributedDataParallel from PyTorch. There's also a legacy DataParallel, which I recommend you not use; that's kind of, you know, legacy. DistributedDataParallel works in a very simple way. We have eight GPUs, so we're going to launch eight processes, and each process is going to be assigned to a GPU. For each process, the training loop and everything we've worked on so far is going to look pretty much the same: each GPU, as far as it's concerned, is just working on exactly what we've built so far. But now, secretly, there are eight of them, and they're all going to be processing slightly different parts of the data. And we're going to add one more part: once they all calculate their gradients, we do an average of those gradients. That's how they're going to be collaborating on the computational workload here.
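The averaging step can be illustrated without any GPUs. The sketch below is a single-process stand-in, not DistributedDataParallel itself: eight simulated ranks each compute a mean gradient on their own equally sized slice of made-up data, and the average of those local gradients matches the gradient over all the data. The toy model, data, and names are assumptions for illustration.

```python
import random

random.seed(0)
world_size = 8   # number of simulated processes
per_rank = 4     # examples per simulated process (hypothetical)
w = 0.5          # single parameter of the toy model y_hat = w * x

data = [(random.uniform(-1, 1), random.uniform(-1, 1))
        for _ in range(world_size * per_rank)]

def mean_grad(w, examples):
    """Gradient with respect to w of the mean squared error over `examples`."""
    return sum(2.0 * (w * x - y) * x for x, y in examples) / len(examples)

# Each simulated rank sees a different, equally sized slice of the data.
local_grads = [mean_grad(w, data[r * per_rank:(r + 1) * per_rank])
               for r in range(world_size)]

# Stand-in for averaging the gradients across processes.
averaged = sum(local_grads) / world_size
global_grad = mean_grad(w, data)
print(averaged, global_grad)
```

The equality relies on every rank holding the same number of examples; with unequal shards a plain average would weight the ranks incorrectly.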
2:48:10
So to use all eight of them, we're not going to be launching our script anymore with just python train_gpt2.py. We're going to be running it with a special command called torchrun in PyTorch; we'll see that in a bit. When torchrun runs our Python script, it will make sure to run eight of them in parallel, and it creates these environment variables where each of these processes can look up which one of the processes it is. So, for example, torchrun will set the RANK, LOCAL_RANK, and WORLD_SIZE environment variables, and so this is a bad way to detect whether DDP is running. If we're using torchrun, if DDP is running, then we have to make sure that CUDA is available, because I don't know that you can run this on CPU anymore, or that that makes sense to do. This is some setup code here. The important part is that there's a world size, which for us will be eight; that's the total number of processes running. There's a rank: each process will basically run the exact same code at roughly the exact same time, and the only difference between these processes is that they all have a different DDP rank. So GPU 0 will have a DDP rank of zero, GPU 1 will have a rank of one, etc. Otherwise they're all running the exact same script; it's just that the DDP rank will be a slightly different integer, and that is the way for us to coordinate that they don't, for example, run on the same data. We want them to run on different parts of the data, and so on.
2:49:52
Now, local rank is something that is only used in a multi-node setting. We only have a single node with eight GPUs, and so local rank is the rank of the GPU on a single node, so from 0 to 7 as an example. But for us, we're mostly going to be running on a single box, so the things we care about are rank and world size. This is eight, and this will be whatever it is, depending on the GPU that this particular instantiation of the script runs on. Now here we make sure that, according to the local rank, we are setting the device to be cuda: followed by the local rank, and the colon indicates which GPU to use if there is more than one GPU. So depending on the local rank of this process, it's going to use just the appropriate GPU, so there are no collisions on which GPU is being used by which process. Finally, there's a Boolean variable that I like to create, which is ddp_rank == 0. So the master process is arbitrarily process number zero, and it does a lot of the printing, logging, checkpointing, etc., and the other processes are thought of mostly as compute processes that are assisting. So master process zero will have some additional work to do; all the other processes will mostly just be doing forward-backward. And if we're not using DDP and none of these variables are set, we revert back to single-GPU training. That means that we only have rank zero, the world size is just one, and we are the master process, and we try to autodetect the device, and this is the world as normal.
2:51:27
So far all we've done is we've initialized DDP, and in the case where we're running with torchrun, which we'll see in a bit, there are going to be eight copies running in parallel. Each one of them will have a different rank, and now we have to make sure that everything happens correctly afterwards. The tricky thing with running multiple processes is that you always have to imagine that there are going to be eight processes running in parallel. So as you read the code now, you have to imagine there are eight Python interpreters running down these lines of code, and the only difference between them is that they have a different DDP rank. So they all come here, they all pick the exact same seed, and they all make all of these calculations, completely unaware of the other copies running, roughly speaking. So they all make the exact same calculations, and now we have to adjust these calculations to take into account that there's actually a certain world size and certain ranks.
2:52:24
In particular, these micro batches and sequence lengths are all just per GPU, right? So now there are going to be num processes of them running in parallel. So we have to adjust this, because grad_accum_steps is now going to be total_batch_size divided by B * T * ddp_world_size, because each process will do B * T and there's this many of them. In addition to that, we want to make sure that this fits nicely into the total batch size, which for us it will, because 16 * 1024 * 8 GPUs is 131K, and so with 524,288 this means that our grad_accum_steps will be four with the current settings. So there's going to be 16 * 1024 processed on each GPU, and then there are eight GPUs, so we're going to be doing 131,000 tokens in a single forward-backward on the eight GPUs. So we want to make sure that this fits nicely, so that we can derive a nice gradient accumulation steps.
2:53:36
steps and uh yeah let's just adjust the
2:53:36
steps and uh yeah let's just adjust the comments here times uh DDP World size
2:53:41
comments here times uh DDP World size
2:53:41
comments here times uh DDP World size okay so each GPU calculates this now
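The arithmetic above can be sketched as a small helper. This is a minimal illustration, not the script's actual code: the function name compute_grad_accum_steps is hypothetical, while the numbers (524,288 total tokens, B = 16, T = 1024, 8 GPUs) are the ones quoted in the video.

```python
def compute_grad_accum_steps(total_batch_size, B, T, ddp_world_size):
    """Return how many micro-steps each process runs per optimizer step.

    Hypothetical helper: every process handles B * T tokens per micro-step,
    and ddp_world_size processes run in parallel.
    """
    tokens_per_micro_step = B * T * ddp_world_size
    # The total batch must divide evenly, otherwise the step count is not an integer.
    assert total_batch_size % tokens_per_micro_step == 0, (
        "total_batch_size must be divisible by B * T * ddp_world_size"
    )
    return total_batch_size // tokens_per_micro_step


# Settings quoted in the video: 2**19 tokens per step, 16 x 1024 per GPU.
steps_8_gpus = compute_grad_accum_steps(524288, B=16, T=1024, ddp_world_size=8)
steps_1_gpu = compute_grad_accum_steps(524288, B=16, T=1024, ddp_world_size=1)
print(steps_8_gpus, steps_1_gpu)
```

With 8 GPUs each optimizer step needs 4 micro-steps per process; on a single GPU the same total batch would need 32.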
2:53:45
OK, so each GPU calculates this. Now this is where we start to run into issues, right? Each process is going to come by a print, and they're all going to print, so we're going to have eight copies of these prints. One way to deal with this is exactly this master_process variable that we have: if master_process, then guard this. That's just so that we print this a single time, because otherwise all the processes would have computed the exact same variables, and there's no need to print this eight times.
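A rough sketch of the guard described above, under some assumptions: torchrun exports RANK, LOCAL_RANK and WORLD_SIZE to every process it launches, and the helper name ddp_settings is hypothetical (the video's script sets these variables inline before this point). The environment is passed in as a plain mapping so the logic can be tried without launching multiple processes.

```python
import os


def ddp_settings(environ):
    """Derive rank information from torchrun-style environment variables.

    Hypothetical helper. Falls back to a single-process setup when RANK
    is not set, i.e. when the script is launched with plain `python`.
    """
    ddp = int(environ.get("RANK", -1)) != -1
    if ddp:
        rank = int(environ["RANK"])
        local_rank = int(environ["LOCAL_RANK"])
        world_size = int(environ["WORLD_SIZE"])
    else:
        rank, local_rank, world_size = 0, 0, 1
    # Only rank 0 does logging, so shared values are printed once, not once per process.
    master_process = rank == 0
    return ddp, rank, local_rank, world_size, master_process


ddp, ddp_rank, ddp_local_rank, ddp_world_size, master_process = ddp_settings(os.environ)
if master_process:
    print(f"world size: {ddp_world_size}")
```

Every process computes the same values, so guarding the print changes only the amount of output, not the training behavior.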
2:54:14
Before getting into the data loader, which we're obviously going to have to refactor, maybe at this point we should do some prints, just take it out for a spin, and exit. So: import sys and sys.exit, and print "I am GPU" with the DDP rank, and then print "Bye".
2:54:49
So now let's try to run this and just see how this works. Let's take it for a spin just so we see what it looks like. Normally we used to launch with python train_gpt2.py, like this. Now we're going to run with torchrun, and this is what it looks like: torchrun --standalone, then the number of processes, which for example is eight for us because we have eight GPUs (--nproc_per_node=8), and then train_gpt2.py. So this is what the command would look like, and torchrun, again, will run eight of these. So let's just see what happens.
2:55:18
First, it gets a little busy, so there's a lot going on here. First of all, there are some warnings from distributed, and I don't actually know that these mean anything. I think this is just the code setting up and the processes coming online, and we're seeing some preliminary failure to collect while the processes come up. I'm not 100% sure about that.
2:55:39
But then we start to get into actual prints. All the processes went down, and then the first print actually comes from process 5, just by chance. So process 5 basically got here first; it said "I am GPU 5", "Bye". And then these prints come from the master process. So process 5 just finished first for whatever reason; it just depends on how the operating system scheduled the processes to run. Then GPU 0 ended, then GPU 3 and 2.
2:56:14
And then probably process 5 or something like that has exited, and DDP really doesn't like that, because we didn't properly dispose of the multi-GPU setting: the process group has not been destroyed before we destruct. So it really doesn't like that, and in an actual application we would want to call destroy_process_group so that we clean up DDP properly. And then the rest of the GPUs finish, and that's it. So basically we can't guarantee when these processes are running; the order is totally arbitrary, but they are running in parallel, and we don't want them all to be printing. Next up, let's erase this.
2:56:59
Next up, we want to make sure that when we create DataLoaderLite, we make it aware of this multi-process setting, because we don't want all the processes to be loading the exact same data. We want every process to get its own chunk of data so that they're all working on different parts of the dataset, of course. So let's adjust that.
2:57:17
One particularly simple and naive way to do this is to make sure that we pass the rank and the size into the data loader. Then when we come up here, we see that we now take the rank and the number of processes and we save them. Now the current position will not be zero, because we want to stride out all the processes. One way to do this is we basically take self.B * self.T and then multiply it by the process rank. So process rank 0 will start at zero, but process rank 1 now starts at B * T, process rank 2 starts at 2 * B * T, etc. So that is the initialization.
2:58:01
Now, they still do this identically, but when we advance, we don't advance by B * T; we advance by B * T times the number of processes. So basically the total number of tokens that we're consuming is B * T * num_processes, they all go off to a different rank, and the position has to advance by the entire chunk. And then here, if advancing by another B * T * self.num_processes + 1 would exceed the number of tokens, then we're going to loop. When we loop, we of course want to loop in the exact same way, so we reset back to this rank's starting offset.
2:58:42
So this is the simplest change that I can find for a very simple distributed DataLoaderLite. You can notice that if process_rank is zero and num_processes is one, then the whole thing will be identical to what we had before. But now we can actually have multiple processes running, and this should work fine. So that's the data loader.
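The strided loader described above can be sketched roughly as follows. To keep the example self-contained it makes one loud assumption: the tokens are a plain Python list of integers rather than a tensor of tokenized text, and batches come back as nested lists. The attribute names (B, T, process_rank, num_processes, current_position) follow the ones mentioned in the video.

```python
class DataLoaderLite:
    """Toy rank-aware loader (assumption: tokens are a plain list, not a tensor)."""

    def __init__(self, tokens, B, T, process_rank=0, num_processes=1):
        self.tokens = tokens
        self.B = B
        self.T = T
        self.process_rank = process_rank
        self.num_processes = num_processes
        # Each rank starts at its own offset: rank 0 at 0, rank 1 at B*T, rank 2 at 2*B*T, ...
        self.current_position = B * T * process_rank

    def next_batch(self):
        B, T = self.B, self.T
        # One extra token so the targets can be the inputs shifted by one.
        buf = self.tokens[self.current_position : self.current_position + B * T + 1]
        x = [buf[i * T : (i + 1) * T] for i in range(B)]          # inputs
        y = [buf[i * T + 1 : (i + 1) * T + 1] for i in range(B)]  # targets
        # Skip over the chunks that the other ranks consume in this step.
        self.current_position += B * T * self.num_processes
        # If the next batch would run past the end, loop back to this rank's start.
        if self.current_position + (B * T * self.num_processes + 1) > len(self.tokens):
            self.current_position = B * T * self.process_rank
        return x, y


tokens = list(range(100))
rank0 = DataLoaderLite(tokens, B=2, T=3, process_rank=0, num_processes=2)
rank1 = DataLoaderLite(tokens, B=2, T=3, process_rank=1, num_processes=2)
print(rank0.next_batch()[0])  # first chunk of the data
print(rank1.next_batch()[0])  # the chunk right after it
```

With process_rank=0 and num_processes=1 the class reduces to the single-process behavior, which is the property pointed out in the video.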
2:59:05
OK, so next up, once they've all initialized the data loader, they come here and they all create a GPT model. So we create eight GPT models on eight processes, but because the seeds are fixed here, they all create the same identical model. They all move it to the device of their rank, and they all compile the model. Because the models are identical, there are eight identical compilations happening in parallel, but that's okay. Now, none of this changes, because that is on a per-step basis, and all the changes we're making are kind of within-step changes.
2:59:44
Now the important thing here is that when we construct the model, we actually have a bit of work to do. (The "get logits" comment here is deprecated, so: "create model".) We need to actually wrap the model into the DistributedDataParallel container. So this is how we wrap the model into the DDP container. These are the docs for DDP, and they're quite extensive; there are a lot of caveats and a lot of things to be careful with, because everything complexifies times 10 when multiple processes are involved.
3:00:15
But roughly speaking, this device_ids, I believe, has to be passed in. Unfortunately, the docs for what device_ids is are extremely unclear. When you actually come here, this comment for what device_ids is is roughly nonsensical, but I'm pretty sure it's supposed to be the DDP local rank: not the DDP rank, the local rank. So this is what you pass in here. This wraps the model.
3:00:43
In particular, what DDP does for you is this: in the forward pass it actually behaves identically. My understanding is that nothing should change in the forward pass. But in the backward pass, in the simplest setting, once the backward pass is over on each independent GPU, each independent GPU has the gradient for all the parameters. What DDP does for you is that once the backward pass is over, it will call what's called all-reduce, which basically averages the gradients across all the ranks and then deposits that average on every single rank. So every single rank will end up with the average on it. That's basically the communication: it just synchronizes and averages the gradients, and that's what DDP offers you.
3:01:31
Now, DDP is actually a little bit more involved than that, because as you are doing the backward pass through the layers of the Transformer, it can dispatch communications for the gradients while the backward pass is still happening. So there's overlap between the communication and synchronization of the gradients and the backward pass, and it's just more efficient to do it that way. So that's what DDP does for you: forward is unchanged, backward is mostly unchanged, and we're tacking on this average, as we'll see in a bit.
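To make the averaging step concrete, here is a toy, single-process simulation of what an averaging all-reduce leaves on each rank. It is only an illustration of the semantics described above: the function name all_reduce_mean is hypothetical, gradients are plain lists of floats, and no real communication or PyTorch call is involved.

```python
def all_reduce_mean(per_rank_grads):
    """Simulate an averaging all-reduce across ranks.

    per_rank_grads[r][i] is rank r's gradient for parameter i (toy floats).
    Returns one list per rank; every rank receives the same averaged gradients.
    """
    world_size = len(per_rank_grads)
    num_params = len(per_rank_grads[0])
    mean = [
        sum(rank_grads[i] for rank_grads in per_rank_grads) / world_size
        for i in range(num_params)
    ]
    # The average is deposited on every rank, so all replicas stay in sync.
    return [list(mean) for _ in range(world_size)]


# Two ranks, two parameters each, with different local gradients.
synced = all_reduce_mean([[1.0, 2.0], [3.0, 6.0]])
print(synced)
```

Because every rank ends up with identical gradients and started from an identical model, the optimizer step keeps all replicas identical.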
3:02:08
OK, so now let's go to the optimization. Nothing here changes. Let's go to the optimization here, the inner loop, and think through the synchronization of these gradients in DDP. By default, as I mentioned, when you do loss.backward() here, it will do the backward pass and then it will synchronize the gradients. The problem is that, because of the gradient accumulation steps loop here, we don't actually want to do the synchronization after every single loss.backward(), because we are just depositing gradients, we're doing that serially, and we just want them adding up. We don't want to synchronize every single time; that would be extremely wasteful. So basically we want to add them up, and only on the very last step, when micro_step becomes grad_accum_steps - 1, do we want to actually do the all-reduce to average up the gradients.
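The schedule described above (accumulate locally, communicate once) can be illustrated with a toy simulation. Assumptions are stated plainly: the function name run_optimizer_step is hypothetical, each rank's gradient is a single float per micro-step, and the all-reduce is simulated by an in-process average rather than by DDP.

```python
def run_optimizer_step(per_rank_micro_grads):
    """Simulate one optimizer step with gradient accumulation on several ranks.

    per_rank_micro_grads[r][m] is rank r's (toy, scalar) gradient for micro-step m.
    Returns the per-rank gradient after the step and the number of all-reduces issued.
    """
    world_size = len(per_rank_micro_grads)
    grad_accum_steps = len(per_rank_micro_grads[0])
    local_grad = [0.0] * world_size
    num_all_reduces = 0
    for micro_step in range(grad_accum_steps):
        # Every rank deposits its micro-step gradient locally; no communication yet.
        for rank in range(world_size):
            local_grad[rank] += per_rank_micro_grads[rank][micro_step]
        # Synchronize only on the last micro-step of the accumulation loop.
        sync_now = micro_step == grad_accum_steps - 1
        if sync_now:
            mean = sum(local_grad) / world_size
            local_grad = [mean] * world_size
            num_all_reduces += 1
    return local_grad, num_all_reduces


grads, comms = run_optimizer_step([[1.0, 2.0, 3.0, 4.0], [2.0, 2.0, 2.0, 2.0]])
print(grads, comms)
```

Synchronizing after every micro-step would give the same final average here but would issue grad_accum_steps all-reduces instead of one, which is the waste the video is pointing at.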
3:03:05
to do that we come here and um the
3:03:05
to do that we come here and um the official sanctioned way by the way is to
3:03:06
official sanctioned way by the way is to
3:03:07
official sanctioned way by the way is to do this no sync context manager so
3:03:10
do this no sync context manager so
3:03:10
do this no sync context manager so pytorch says this is a context manager
3:03:12
pytorch says this is a context manager
3:03:13
pytorch says this is a context manager to disable gradient synchronization
3:03:14
to disable gradient synchronization
3:03:14
to disable gradient synchronization across DDP processes So within this
3:03:16
across DDP processes So within this
3:03:17
across DDP processes So within this context gradient will be
3:03:19
context gradient will be
3:03:19
context gradient will be accumulated and basically when you do no
3:03:21
accumulated and basically when you do no
3:03:21
accumulated and basically when you do no sync there will be no communication so
3:03:24
sync there will be no communication so
3:03:24
sync there will be no communication so they are telling us to do with DDP no
3:03:26
they are telling us to do with DDP no
3:03:26
they are telling us to do with DDP no sync uh do the gradient accumulation
3:03:29
sync uh do the gradient accumulation
3:03:29
sync uh do the gradient accumulation accumulate grats and then they are
3:03:30
accumulate grats and then they are
3:03:30
accumulate grats and then they are asking us to do DDP again with another
3:03:32
asking us to do DDP again with another
3:03:32
asking us to do DDP again with another input and that backward and I just
3:03:35
input and that backward and I just
3:03:35
input and that backward and I just really don't love this I I just really
3:03:37
really don't love this I I just really
3:03:37
really don't love this I I just really don't like it uh the fact that you have
3:03:39
don't like it uh the fact that you have
3:03:39
don't like it uh the fact that you have to copy paste your code here and use a
3:03:40
to copy paste your code here and use a
3:03:40
to copy paste your code here and use a context manager and this is just super
3:03:42
context manager and this is just super
3:03:42
context manager and this is just super ugly so when I went to this source code
3:03:45
So when I went to the source code here, you can see that when you enter, you simply toggle this variable, require_backward_grad_sync. It is being toggled around and changed, and if you step through it, this is the variable that determines whether the gradient is going to be synchronized. So I actually just like to use that directly.
3:04:10
Instead, what I like to do is the following. Right here, before the loss.backward(), if we are using DDP, then we only want to synchronize, we only want this variable to be True, when it is the final iteration. In all the other iterations inside the micro steps we want it to be False. So I just toggle it like this.
3:04:36
require_backward_grad_sync should only turn on when the micro step is the last step. So I'm toggling this variable directly, and I hope that that impacts loss.backward(). This is a naughty thing to do because, you know, they could probably change DDP and this variable will go away. But for now I believe this works, and it allows me to avoid the use of context managers and code duplication. I'm just toggling the variable, and then loss.backward() will not synchronize on most of the steps, and it will synchronize on the very last step.
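A minimal sketch of the toggle described here, kept free of PyTorch so it runs anywhere. The helper name should_sync and the value grad_accum_steps = 4 are assumptions for illustration; in the actual training loop the same boolean would be assigned to model.require_backward_grad_sync right before loss.backward().

```python
# Sketch only: decide, per micro step, whether DDP should all-reduce gradients.
# `should_sync` is a hypothetical helper, not part of PyTorch.

def should_sync(micro_step: int, grad_accum_steps: int) -> bool:
    """Return True only on the last micro step of an accumulation window."""
    return micro_step == grad_accum_steps - 1


grad_accum_steps = 4  # illustrative value
sync_schedule = [should_sync(s, grad_accum_steps) for s in range(grad_accum_steps)]
print(sync_schedule)

# In the real loop (requires torch and a DDP-wrapped model) this would be roughly:
#     if ddp:
#         model.require_backward_grad_sync = should_sync(micro_step, grad_accum_steps)
#     loss.backward()
```

Because require_backward_grad_sync is an internal attribute of the DDP wrapper, the no_sync context manager remains the documented route if this attribute ever changes.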
3:05:08
Once this is over and we come out, every single rank will suddenly, magically have the average of all the gradients that were stored on all the ranks. So now we have to think through whether that is what we want, whether this suffices, and how it works with the loss and with loss_accum. Let's think through that now.
3:05:35
The problem I'm getting at is that we've averaged the gradients, which is great, but loss_accum has not been impacted yet. It lives outside of the DDP container, so it is not being averaged. So here, when we are printing loss_accum, well, presumably we're only going to be printing on the master process, rank zero, and it's just going to print the losses that it saw on its own process. Instead we want it to print the loss over all the processes, the average of that loss. Because we did an average of the gradients, we want the average of the loss as well.
3:06:06
So simply here, after this, this is the code that I've used in the past, and instead of lossf we want loss_accum. So if ddp, again, then this is PyTorch distributed. I import it... where do I import it? Oh gosh, this file is starting to get out of control, huh. So, import torch.distributed as dist. Then dist.all_reduce, and we're doing the average on loss_accum. This loss_accum tensor exists on all the ranks. When we call all_reduce with average, it creates the average of those numbers and deposits that average on all the ranks. So after this call all the ranks will contain loss_accum averaged up, and when we print here on the master process, loss_accum is identical on all the other ranks as well.
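To make the averaging concrete, here is a toy single-process stand-in for what an all-reduce with averaging does to one scalar per rank. The function all_reduce_avg and the loss values are made up for illustration; in the training script the real call would be dist.all_reduce(loss_accum, op=dist.ReduceOp.AVG) from torch.distributed.

```python
# Toy simulation: each list entry plays the role of loss_accum on one rank.
# `all_reduce_avg` is a hypothetical helper, not a torch.distributed function.

def all_reduce_avg(values_per_rank):
    """Return what every rank holds after an averaging all-reduce."""
    mean = sum(values_per_rank) / len(values_per_rank)
    # Every rank ends up with the same averaged value.
    return [mean] * len(values_per_rank)


local_losses = [10.0, 11.0, 9.0, 10.0]  # made-up per-rank losses
synced = all_reduce_avg(local_losses)
print(synced)
```

After the call every entry is identical, which is why printing only on the master process is enough.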
3:07:04
So here, if master_process, oops, we want to print like this. Okay. And finally we have to be careful, because we're now processing even more tokens, so times ddp_world_size. That's the number of tokens that we've processed up above.
3:07:24
And everything else should be fine. The only other thing to be careful with is, as I mentioned, you want to destroy the process group, so that we are nice to NCCL and to DDP and it's not going to complain to us when we exit here. So that should be it.
3:07:43
Let's try to take it for a spin. Okay, so I launched the script and it should be printing here imminently. We're now training with 8 GPUs at the same time, so the gradient accumulation steps is not 32; it is now divided by 8, and it's just 4. Otherwise, this is what the optimization now looks like, and wow, we're going really fast. We're processing 1.5 million tokens per second now.
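The arithmetic behind those numbers can be sketched as follows. The values total_batch_size = 524288, B = 16 and T = 1024 are assumptions chosen to be consistent with the 32 and 4 accumulation steps quoted here, and the helper names are hypothetical.

```python
# Sketch of the batch-size bookkeeping under the stated assumptions.
total_batch_size = 524288  # assumed tokens per optimizer step (2**19)
B, T = 16, 1024            # assumed micro-batch size and sequence length


def grad_accum_steps_for(world_size: int) -> int:
    """Micro steps each process runs so all processes together hit total_batch_size."""
    tokens_per_micro_step = B * T * world_size
    assert total_batch_size % tokens_per_micro_step == 0
    return total_batch_size // tokens_per_micro_step


def tokens_per_second(world_size: int, step_seconds: float) -> float:
    """Throughput: tokens processed across all ranks divided by the step time."""
    tokens_processed = B * T * grad_accum_steps_for(world_size) * world_size
    return tokens_processed / step_seconds


print(grad_accum_steps_for(1), grad_accum_steps_for(8))
```

Multiplying by the world size is what keeps the reported tokens per second correct once several processes each consume their own micro-batches.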
3:08:09
So these are some serious numbers, and the tiny Shakespeare dataset is so tiny that we're most likely just doing many epochs over it. But this is roughly what it looks like.
3:08:20
One thing that I had to fix, by the way, is that this was model.configure_optimizers, which now doesn't work because model is now a DDP model. So instead this has to become raw_model.configure_optimizers, where raw_model is something I create here. Right after I wrap the model into DDP, I have to create the raw model, which in the case of DDP is model.module. That is where it stores the raw nn.Module of GPT-2 as we have it, which contains the configure_optimizers function that we want to call. So that's one thing that I had to fix. Otherwise this seems to run.
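A toy illustration of why the unwrapping is needed. TinyModel and ToyWrapper are hypothetical stand-ins (not the real GPT module or the real DDP class), and the hyperparameter values are only illustrative; the point is that a wrapper which stores the original model under .module does not expose that model's custom methods.

```python
# Hypothetical stand-ins, used only to show the unwrap pattern.

class TinyModel:
    """Plays the role of the GPT module; only the method we care about."""

    def configure_optimizers(self, weight_decay, learning_rate):
        return {"weight_decay": weight_decay, "lr": learning_rate}


class ToyWrapper:
    """Plays the role of a DDP-style wrapper: keeps the model in .module."""

    def __init__(self, module):
        self.module = module


ddp = True
model = TinyModel()
if ddp:
    model = ToyWrapper(model)

# Always keep a handle on the unwrapped model for custom method calls.
raw_model = model.module if ddp else model
optimizer_cfg = raw_model.configure_optimizers(weight_decay=0.1, learning_rate=6e-4)
print(optimizer_cfg)
```

The same one-line conditional works whether or not the run is distributed, so the rest of the script can call raw_model unconditionally.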
3:08:56
Now, one thing you'll notice is that when you actually compare this run and the numbers in it to just running on a single GPU (this is the single-GPU run with 32 gradient accumulation steps), the numbers won't exactly match up. There's kind of a boring reason for why that happens. The reason is that in the data loader we're iterating through batches in a slightly different way, because now we're looking for an entire page of data for all the GPUs, and if that chunk exceeds the number of tokens, we just loop. So the single-GPU and the 8-GPU processes will end up resetting in a slightly different manner, and so our batches are slightly different, and so we get slightly different numbers.
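The boundary effect can be reproduced with a small simulation. This is a sketch that assumes the loader advances by B * T * num_processes tokens per batch and resets to its rank offset when the next page for all processes would run past the data; the function name and the tiny sizes are made up for illustration.

```python
# Sketch of the loader's position logic under the stated assumptions.

def batch_starts(num_tokens, B, T, num_processes, process_rank, num_batches):
    """Start offsets of successive batches as seen by one process."""
    stride = B * T * num_processes
    pos = B * T * process_rank
    starts = []
    for _ in range(num_batches):
        starts.append(pos)
        pos += stride
        # If the next page for all processes would run out of bounds, wrap around.
        if pos + stride + 1 > num_tokens:
            pos = B * T * process_rank
    return starts


# 90 tokens, micro-batches of 4 tokens each.
single = batch_starts(90, 1, 4, num_processes=1, process_rank=0, num_batches=24)
rank0_of_8 = batch_starts(90, 1, 4, num_processes=8, process_rank=0, num_batches=3)
print(single[20:], rank0_of_8)
```

In this toy setup the single process reads offsets up to 84 before wrapping, while the 8-process run wraps after covering only offsets 0 through 63, so the two runs stop seeing the same tokens after the first pass.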
3:09:39
But one way to convince yourself that this is okay is to just make the total batch size much smaller, along with B and T. I think I used 4 * 1024 * 8, so I used 32768 as the total batch size. That way I made sure that the single GPU will do eight gradient accumulation steps and the multi-GPU run will do just one. Then you're reducing the boundary effects of the data loader, and you'll see that the numbers match up.
3:10:08
So, long story short, we're now going really, really fast. The optimization is mostly consistent with the GPT-2 and GPT-3 hyperparameters, and we have outgrown our tiny Shakespeare file and want to upgrade it. So let's move on to that next.
3:10:22
So let's now take a look at what datasets were used by GPT-2 and GPT-3. GPT-2 used this WebText dataset, which was never released. There's an attempt at reproducing it called OpenWebText. Roughly speaking, what they say here in the paper is that they scraped all outbound links from Reddit with at least three karma, and that was kind of their starting point. They collected all the web pages and all the text in them. This was 45 million links, and it ended up being 40 GB of text. So that's roughly what GPT-2 says about its dataset: it's basically outbound links from Reddit.
3:11:01
Now, when we go over to GPT-3, there's a training dataset section, and that's where they start to talk about Common Crawl, which is used a lot more. Actually, I think even GPT-2 talked about Common Crawl. But basically it's not a very high-quality dataset all by itself, because it is extremely noisy. This is a completely random subset of the internet, and it's much worse than you think. So people go to great lengths to filter Common Crawl, because there's good stuff in it, but most of it is just ad spam, random tables and numbers, stock tickers; it's just a total mess.
3:11:35
So that's why people like to train on these data mixtures that they curate and are careful with. A large chunk of these data mixtures will typically be Common Crawl; for example, 50% of the tokens will be Common Crawl. But here in GPT-3 they're also using WebText2 from before, so that's the Reddit outbound links, and they're also adding, for example, books and Wikipedia. There are many other things you can decide to add.
3:12:00
Now, this dataset for GPT-3 was also never released. Today, some of the datasets that I'm familiar with that are quite good and would be representative of something along these lines are, number one, the RedPajama dataset, or more specifically, for example, the SlimPajama subset of the RedPajama dataset, which is a cleaned and deduplicated version of it. Just to give you a sense, again it's a bunch of Common Crawl; C4, which as far as I know is also more Common Crawl but processed differently; and then we have GitHub, books, arXiv, Wikipedia, StackExchange. These are the kinds of datasets that would go into these data mixtures.
3:12:37
Now, specifically, the one that I like that came out recently is called the FineWeb dataset. This is an attempt to collect really high-quality Common Crawl data and filter it, in this case to 15 trillion tokens. In addition to that, more recently Hugging Face released this FineWeb-Edu subset, which is 1.3 trillion tokens of educational and 5.4 trillion tokens of high educational content. So basically they're trying to filter Common Crawl down to very high-quality educational subsets, and this is the one that we will use.
3:13:11
There's a long web page here on FineWeb, and they go into a ton of detail about how they process the data, which is really fascinating reading, by the way. If you're interested in data mixtures and in how data gets processed at these scales, I would definitely recommend taking a look at this page. More specifically, we'll be working with FineWeb-Edu, which is basically educational content from the internet. They show that training on educational content works really, really well on their metrics. We're going to use this sample-10BT, a 10-billion-token subsample of it, because we're not going to be training on trillions of tokens. We're just going to train on the 10-billion-token sample of FineWeb-Edu, because empirically, in my previous few experiments, this actually suffices to get really close to GPT-2 performance, and it's simple enough to work with.
3:14:04
it's um simple enough to work with and so let's work with the sample 10 uh BT
3:14:07
so let's work with the sample 10 uh BT
3:14:07
so let's work with the sample 10 uh BT so our goal will be to download it
3:14:10
so our goal will be to download it
3:14:10
so our goal will be to download it process it and make sure that our data
3:14:12
process it and make sure that our data
3:14:12
process it and make sure that our data loader can work with it so let's get to
3:14:15
loader can work with it so let's get to
3:14:15
loader can work with it so let's get to that okay so I introduced another um
3:14:18
Okay, so I introduced another file here that will basically download FineWeb-Edu from Hugging Face datasets, pre-process and pre-tokenize all of the data, and save data shards to a folder on local disk.
3:14:38
And so while this is running, I just wanted to briefly mention that you can look through the dataset viewer here just to get a sense of what's in here, and it's kind of interesting. I mean, it basically looks like it's working fairly well: it's talking about nuclear energy in France, it's talking about Mexican America, some mac PJs, etc. So actually it seems like their filters are working pretty well. The filters here, by the way, were applied automatically using Llama 3 70B, I believe, so basically LLMs are judging which content is educational, and that is what ends up making it through the filter. So that's pretty cool.
3:15:16
Now, in terms of the script itself, I'm not going to go through the full script because it's not as interesting and not as LLM-centric. But when you run this, number one, we're going to load the dataset. This is all Hugging Face code; to run it you're going to need to pip install datasets. So it's downloading the dataset, and then it is tokenizing all of the documents inside this dataset.
3:15:39
Now, when we tokenize the documents, you'll notice that to tokenize a single document we first start the tokens with the end-of-text token. This is a special token in the GPT-2 tokenizer, as you know: 50256 is the ID of end-of-text, and this is what begins a document. Even though it's called end-of-text, it is the first token that begins a document. Then we extend with all of the tokens of that document, then we create a numpy array out of that, and we make sure that all the tokens are between... oh, okay, let me debug this.
3:16:19
Okay, so apologies for that. It just had to do with me using a float division in Python; it must be integer division so that this is an int and everything is nice. Okay, but basically the tokenization here is relatively straightforward: it returns tokens in np.uint16. We're using uint16 to save a little bit of space, because 2 to the 16 minus 1 is 65,535, so the GPT-2 max token ID is well below that.
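The tokenization step just described can be sketched roughly as follows. This is a minimal illustration, not the script from the video: the helper prepends the end-of-text ID 50256, appends the document's tokens, checks that every ID fits in 16 bits, and returns an np.uint16 array. The encode parameter and the toy fake_encode function are assumptions introduced here so the example runs without downloading a tokenizer; in practice one would pass the GPT-2 tokenizer's encode function instead.

```python
import numpy as np

EOT = 50256  # ID of the end-of-text token in the GPT-2 tokenizer (from the transcript)

def tokenize(text, encode):
    """Return one document as a uint16 array that starts with the EOT delimiter.

    `encode` is any callable mapping a string to a list of int token IDs
    (hypothetical parameter; the real script would use the GPT-2 tokenizer).
    """
    tokens = [EOT]                # the delimiter goes first, before the document
    tokens.extend(encode(text))   # then all tokens of the document
    tokens_np = np.array(tokens, dtype=np.int64)
    # Every ID must fit in an unsigned 16-bit integer: 0 <= id < 2**16.
    assert (0 <= tokens_np).all() and (tokens_np < 2**16).all(), "token ID too large for uint16"
    return tokens_np.astype(np.uint16)

def fake_encode(text):
    """Toy stand-in for a tokenizer: one ID per UTF-8 byte (illustration only)."""
    return list(text.encode("utf-8"))

doc_tokens = tokenize("hi", fake_encode)
print(doc_tokens, doc_tokens.dtype)
```

Storing IDs as uint16 halves the disk footprint relative to int32, which adds up over billions of tokens; the range check guards against silently wrapping an ID that does not fit.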
3:16:45
And then here there's a bunch of multiprocessing code, and it's honestly not that exciting, so I'm not going to step through it. But we're loading the dataset, we're tokenizing it, and we're saving everything to shards. The shards are numpy files, so each one is just storing a numpy array, which is very similar to a torch tensor.
3:17:07
The first shard, 0000, is a validation shard, and all the other shards are training shards. As I mentioned, they all have exactly 100 million tokens in them, and sharding the files just makes them easier to work with, because if we just have a single massive file, sometimes it can be hard to work with on disk. So sharding is just kind of nicer from that perspective.
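A minimal sketch of the sharding idea, under stated assumptions: tokens from successive documents are packed into a fixed-size buffer, each full buffer is written to its own .npy file, and shard 0 is labelled as the validation shard. The shard size is shrunk to 10 tokens for illustration (the transcript uses 100 million), the write_shards helper and the file naming are hypothetical, and writing the leftover tokens as a shorter final shard is a choice made here, not something the transcript states.

```python
import os
import tempfile
import numpy as np

def write_shards(docs, shard_size, out_dir):
    """Pack token arrays into fixed-size shards and save each one as a .npy file.

    Shard 0 is labelled "val", all later shards "train". Returns the file paths.
    """
    paths = []
    buf = np.empty((shard_size,), dtype=np.uint16)  # preallocated shard buffer
    count = 0  # number of tokens currently in buf

    def flush(n):
        split = "val" if len(paths) == 0 else "train"
        path = os.path.join(out_dir, f"shard_{split}_{len(paths):06d}.npy")
        np.save(path, buf[:n])
        paths.append(path)

    for tokens in docs:
        while len(tokens) > 0:
            take = min(shard_size - count, len(tokens))  # integer arithmetic only
            buf[count:count + take] = tokens[:take]
            count += take
            tokens = tokens[take:]
            if count == shard_size:  # shard is full: write it and start a new one
                flush(count)
                count = 0
    if count > 0:                    # leftover tokens form a final, shorter shard
        flush(count)
    return paths

# Five toy "documents" of 7 tokens each, 35 tokens in total.
docs = [np.arange(i, i + 7, dtype=np.uint16) for i in range(0, 35, 7)]
with tempfile.TemporaryDirectory() as d:
    paths = write_shards(docs, shard_size=10, out_dir=d)
    names = [os.path.basename(p) for p in paths]
    sizes = [len(np.load(p)) for p in paths]
print(names, sizes)
```

Note that a document can straddle two shards here; since each document starts with its own delimiter token, the stream can still be split back into documents after concatenation.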
3:17:32
And yeah, so we'll just let this run. This will probably be 30-ish minutes or so, and then we're going to come back to actually train on this data. And we're going to be doing some legit pre-training in this case: this is a good dataset, we're doing lots of tokens per second, we have 8 GPUs, the code is ready, and so we're actually going to be doing a serious training run. So let's get back to it in a bit.
3:17:54
Okay, so we're back. If we ls edu fineweb, we see that there are now 100 shards in it, and that makes sense because each shard is 100 million tokens, so 100 shards of that is 10 billion tokens in total.
3:18:11
Now, swinging over to the main file, I made some adjustments to our data loader again, and that's because we're not running with Shakespeare anymore; we want to use the FineWeb shards. So you'll see some code here that can additionally load these shards: we load the uint16 numpy file and convert it to a torch.long tensor, which is what a lot of the layers up top expect by default. And then here we're just enumerating all the shards.
3:18:38
I also added a split to DataLoaderLite, so we can load the split train but also the split val, the zeroth shard. Then we can load the shards, and here we now have not just the current position but also the current shard. So we have a position inside a shard, and when we run out of tokens in a single shard, we first advance the shard, looping around if we need to, and then we get the tokens and readjust the position. So this data loader will now iterate over all the shards as well. So I changed that.
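The shard-advancing logic can be sketched as below. This is a simplified stand-in, not the DataLoaderLite from the video: it keeps the shards in memory as numpy arrays, handles a single process only (any multi-GPU striding is omitted), and widens the uint16 data to int64 where the real loader converts to a torch.long tensor. The class name ShardedLoader and its fields are hypothetical.

```python
import numpy as np

class ShardedLoader:
    """Toy loader over in-memory shards (hypothetical stand-in for the real class)."""

    def __init__(self, shards, B, T):
        self.shards = shards  # list of 1-D uint16 arrays, one per shard
        self.B, self.T = B, T
        self.current_shard = 0
        self.tokens = self._load(self.current_shard)
        self.current_position = 0

    def _load(self, i):
        # Widen uint16 -> int64, mirroring the conversion to a long tensor.
        return self.shards[i].astype(np.int64)

    def next_batch(self):
        B, T = self.B, self.T
        buf = self.tokens[self.current_position:self.current_position + B * T + 1]
        x = buf[:-1].reshape(B, T)  # inputs
        y = buf[1:].reshape(B, T)   # targets: inputs shifted by one token
        self.current_position += B * T
        # If the next batch would run past the end, advance the shard (wrapping around).
        if self.current_position + B * T + 1 > len(self.tokens):
            self.current_shard = (self.current_shard + 1) % len(self.shards)
            self.tokens = self._load(self.current_shard)
            self.current_position = 0
        return x, y

shards = [np.arange(0, 20, dtype=np.uint16), np.arange(100, 120, dtype=np.uint16)]
loader = ShardedLoader(shards, B=2, T=4)
seen = []
for _ in range(5):
    shard_before = loader.current_shard  # shard the upcoming batch is read from
    x, y = loader.next_batch()
    seen.append((shard_before, int(x[0, 0])))
print(seen)
```

Each entry of seen records which shard a batch came from and its first token. With two 20-token shards and 8 tokens per batch, the loader takes two batches from each shard and then wraps back to the first one; the few tokens at the end of a shard that cannot fill a whole batch are skipped.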
3:19:11
And then the other thing that I did while the data was processing is: our train loader now has split train, of course, and down here I set up some numbers. So we are doing 2 to the 19 tokens per step, and we want to do roughly 10 billion tokens, because that's how many unique tokens we have. So if we take 10 billion tokens and divide that by 2 to the 19, we see that this is 19,073 steps, so that's where that's from.
3:19:47
And then the GPT-3 paper says that they warm up the learning rate over 375 million tokens. So I came here, and 375e6 tokens divided by 2 to the 19 is 715 steps, so that's why warmup steps is set to 715. So this will exactly match the warmup schedule that GPT-3 used. And I think 715, by the way, is very mild, and this could be made significantly more aggressive; probably even like 100 is good enough. But it's okay, let's leave it for now so that we have the exact hyperparameters of GPT-3. So I fixed that.
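The two step counts follow from simple integer division, as this short check shows. The variable names are illustrative; only the numbers come from the transcript.

```python
total_batch_size = 2**19          # tokens per optimizer step (524,288)
dataset_tokens = 10_000_000_000   # roughly 10 billion unique tokens in the sample
warmup_tokens = 375_000_000       # warmup length quoted from the GPT-3 paper

max_steps = dataset_tokens // total_batch_size    # steps for one pass over the data
warmup_steps = warmup_tokens // total_batch_size  # steps spent warming up

print(total_batch_size, max_steps, warmup_steps)
```

Both quotients are rounded down, so one pass over the data is 19,073 steps and the warmup is 715 steps, about 3.7% of the run.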
3:20:25
And then that's pretty much it, we can run. So we have our script here and we can launch. And actually, sorry, let me do one more thing.
3:20:40
Excuse me. For my GPU I can actually fit a bigger batch size, and I believe I can fit 64 on my GPU as a micro batch size, so let me try that. I could be misremembering, but that means 64 times 1024 per GPU, and then we have 8 GPUs, so that means we would not even be doing gradient accumulation if this fits, because this just multiplies out to the full total batch size. So no gradient accumulation, and that would run pretty quickly if that fits.
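To see why no gradient accumulation is needed, here is the arithmetic as a small sketch. It assumes a sequence length of 1024 tokens (the garbled "64 * 124" in the captions is read as 64 times 1024) and uses illustrative variable names; accum_steps_for is a hypothetical helper added to show what happens with smaller micro batches.

```python
total_batch_size = 2**19  # desired tokens per optimizer step
B = 64                    # micro batch size (sequences per GPU per forward pass)
T = 1024                  # sequence length in tokens (assumed)
num_gpus = 8              # number of data-parallel processes

tokens_per_micro_step = B * T * num_gpus
assert total_batch_size % tokens_per_micro_step == 0, "total batch must divide evenly"
grad_accum_steps = total_batch_size // tokens_per_micro_step
print(tokens_per_micro_step, grad_accum_steps)

def accum_steps_for(B, T=1024, num_gpus=8, total=2**19):
    """Gradient accumulation steps needed for a given micro batch size."""
    per_step = B * T * num_gpus
    assert total % per_step == 0, "total batch must divide evenly"
    return total // per_step

# Halving the micro batch size doubles the number of accumulation steps.
print([(b, accum_steps_for(b)) for b in (64, 32, 16)])
```

With B = 64 the product 64 × 1024 × 8 already equals 2 to the 19, so a single forward and backward pass per GPU makes up the full batch.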
3:21:26
let's go let's go I mean if this works
3:21:29
let's go let's go I mean if this works
3:21:29
let's go let's go I mean if this works then this is basically a serious
3:21:31
then this is basically a serious
3:21:31
then this is basically a serious pre-training run um we're not logging
3:21:33
pre-training run um we're not logging
3:21:34
pre-training run um we're not logging we're not evaluating the validation
3:21:35
we're not evaluating the validation
3:21:35
we're not evaluating the validation split we're not running any evaluations
3:21:37
split we're not running any evaluations
3:21:37
split we're not running any evaluations yet so it's not we haven't crossed our
3:21:39
yet so it's not we haven't crossed our
3:21:39
yet so it's not we haven't crossed our te's and dotted our eyes but uh if we
3:21:42
te's and dotted our eyes but uh if we
3:21:42
te's and dotted our eyes but uh if we let this run for a while we're going to
3:21:44
let this run for a while we're going to
3:21:44
let this run for a while we're going to actually get a pretty good model and the
3:21:46
actually get a pretty good model and the
3:21:46
actually get a pretty good model and the model that might even be on par with or
3:21:49
model that might even be on par with or
3:21:49
model that might even be on par with or better than gpt2 124 M okay so it looks
3:21:54
better than gpt2 124 M okay so it looks
3:21:54
better than gpt2 124 M okay so it looks like everything is going great we're
3:21:55
like everything is going great we're
3:21:55
like everything is going great we're processing 1.5 million tokens per
3:21:58
processing 1.5 million tokens per
3:21:58
processing 1.5 million tokens per second uh everything here looks good
3:22:03
second uh everything here looks good
3:22:03
second uh everything here looks good we're doing 330 milliseconds per
3:22:06
we're doing 330 milliseconds per
3:22:06
we're doing 330 milliseconds per iteration and we have to do a total
3:22:09
iteration and we have to do a total
3:22:09
iteration and we have to do a total of uh where are we printing that 1973 so
3:22:13
of uh where are we printing that 1973 so
3:22:13
of uh where are we printing that 1973 so 19073 times 0.33
3:22:17
19073 times 0.33
3:22:17
19073 times 0.33 is this many seconds this many minutes
3:22:20
is this many seconds this many minutes
3:22:20
is this many seconds this many minutes so this will run for 1.7
3:22:24
so this will run for 1.7
3:22:24
so this will run for 1.7 hours uh so one and a half hour run uh
3:22:28
hours uh so one and a half hour run uh
3:22:28
hours uh so one and a half hour run uh like this and uh we don't even have to
3:22:30
like this and uh we don't even have to
3:22:30
like this and uh we don't even have to use gradient accumulation which is nice
3:22:31
use gradient accumulation which is nice
3:22:31
use gradient accumulation which is nice and you might not have that luxury in
3:22:33
and you might not have that luxury in
3:22:33
and you might not have that luxury in your GPU in that case just start
3:22:35
your GPU in that case just start
3:22:35
your GPU in that case just start decreasing the batch size until things
3:22:36
decreasing the batch size until things
3:22:37
decreasing the batch size until things fit but keep it to nice
3:22:39
fit but keep it to nice
3:22:39
fit but keep it to nice numbers um so that's pretty exciting
3:22:42
So that's pretty exciting. We're currently warming up the learning rate, so you see that it's still very low, 1e-4, and this will ramp up over the next few steps all the way to 6e-4 here. Very cool.
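A sketch of a schedule that behaves this way: linear warmup to the 6e-4 peak over 715 steps. What happens after warmup is not described in this part of the transcript, so the cosine decay to 10% of the peak used below is an assumption, as are the function and variable names.

```python
import math

max_lr = 6e-4          # peak learning rate mentioned in the transcript
min_lr = max_lr * 0.1  # assumed floor for the decay (not stated in this section)
warmup_steps = 715
max_steps = 19073

def get_lr(step):
    """Linear warmup to max_lr, then cosine decay to min_lr (decay shape is an assumption)."""
    if step < warmup_steps:
        return max_lr * (step + 1) / warmup_steps
    if step > max_steps:
        return min_lr
    decay_ratio = (step - warmup_steps) / (max_steps - warmup_steps)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # goes from 1 down to 0
    return min_lr + coeff * (max_lr - min_lr)

for step in (0, 118, 714, 19073):
    print(step, f"{get_lr(step):.2e}")
```

Under this schedule the learning rate passes roughly 1e-4 around step 118 and reaches the 6e-4 peak at step 714, the last warmup step.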
3:22:56
So now what I'd like to do is cross the t's and dot the i's: let's evaluate on the validation split, and let's try to figure out how we can run evals, how we can do logging, how we can visualize our losses, and all the good stuff. So let's get to that before we actually do the run.
3:23:11
Okay, so I've adjusted the code so that we're evaluating on the validation split. We create the val loader just by passing in split equals val; that will basically create a data loader just for the validation shard. The other thing I did is that in the data loader I introduced a new function, reset, which is called at init, and it basically resets the data loader. That is very useful, because when we come to the main training loop now, this is the code that I've added: every 100th iteration, including the zeroth iteration, we put the model into evaluation mode, we reset the val loader, and then, with no gradients involved, we basically accumulate the loss over, say, 20 steps, average it all up, and print out the validation loss.
3:24:01
And so that is basically the exact same logic as the training loop, roughly, but there's no loss.backward(); it's only inference. We're just measuring the loss and adding it up; everything else otherwise applies and is exactly as we've seen it before. And so this will print the validation loss every 100th iteration, including on the very first iteration. So that's nice.
3:24:23
iteration uh so that's nice that will
3:24:23
iteration uh so that's nice that will tell us some amount some a little bit
3:24:24
tell us some amount some a little bit
3:24:25
tell us some amount some a little bit about how much we're overfitting that
3:24:27
That said, we have roughly infinite data, so we're mostly expecting our train and val loss to be about the same. The other reason I'm interested in this is that we can take the GPT-2 124M as OpenAI released it, initialize from it, and see what kind of loss it achieves on the validation set as well. That gives us an indication of how well that model would generalize to the FineWeb-Edu validation split. It's not a super fair comparison to GPT-2, because it was trained on a very different data distribution, but it's still an interesting data point.
3:25:02
In any case, you would always want to have a validation split in a training run like this, so that you can make sure you are not overfitting. This is especially a concern if we were to do more epochs over our training data. Right now we're just doing a single epoch, but if we got to a point where we wanted to train for 10 epochs or something like that, we would have to be really careful about whether we are memorizing that data too much, if we have a big enough model. Our validation split would be one way to tell whether that is happening.
3:25:34
Okay, and in addition to that, if you remember, at the bottom of our script we had all of this orphaned code for sampling from way back when. I deleted that code and moved it up to here. So once in a while we evaluate on the validation set, once in a while we generate samples, and we do that only every 100 steps, while we train on every single step. That's how I have it structured right now.
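The loop structure just described can be sketched as follows. This is only a skeleton under stated assumptions: `run_schedule` is a hypothetical name, the real work is replaced by log entries, and whether sampling also runs on step 0 is not stated in the talk, so the sketch simply samples on the same steps as validation.

```python
def run_schedule(max_steps, every=100):
    """Record which phases run on which step of the loop."""
    log = []
    for step in range(max_steps):
        if step % every == 0:
            log.append((step, "validate"))  # includes the very first step
            log.append((step, "sample"))    # assumption: same cadence as validation
        log.append((step, "train"))         # an optimizer step on every iteration
    return log

log = run_schedule(max_steps=250)
val_steps = [step for step, phase in log if phase == "validate"]
train_count = sum(1 for _, phase in log if phase == "train")
print(val_steps, train_count)
```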
3:25:58
I've been running this for 1,000 iterations, so here are some samples at iteration 1,000:
"Hello, I'm a language model, and I'm not able to get more creative"
"I'm a language model and languages file you're learning about here is or is the beginning of a computer"
3:26:16
Okay, so this is all still pretty garbled, but we're only at iteration 1,000 and we've only just barely reached the maximum learning rate, so this is still learning. We're about to get some more samples coming up at 1,100.
3:26:38
Okay, the model is still a young baby.
3:26:42
All of this sampling code that I've put here should be familiar to you; it came from before. The only thing I did is create a generator object in PyTorch, so that I have direct control over the sampling of the random numbers, because I don't want to impact the state of the global random number generator that is used for training. I want this to be completely outside of the training loop, so I'm using a special sampling RNG. I make sure to seed it so that every single rank has a different seed, and then, where we consume the random numbers in multinomial, where the sampling happens, I make sure to pass in the generator object. Otherwise this is identical.
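To illustrate the idea of a dedicated sampling RNG, here is a small sketch that uses Python's standard `random` module instead of PyTorch, so that it runs without any dependencies. In the PyTorch script the corresponding pieces would be a `torch.Generator` seeded with `manual_seed` and passed through the `generator=` argument of `torch.multinomial`. The names `make_sample_rng` and `sample_token` and the base seed of 42 are assumptions made for this example.

```python
import random

def make_sample_rng(rank, base_seed=42):
    """Create a private RNG for sampling; each rank gets a different seed."""
    return random.Random(base_seed + rank)

def sample_token(probs, rng):
    """Draw one token index from `probs` using only the private RNG."""
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]

# The global RNG stands in for the one that training depends on.
random.seed(1337)
expected_next = random.random()

# Same starting state, but this time we sample in between.
random.seed(1337)
rng = make_sample_rng(rank=0)
tokens = [sample_token([0.1, 0.2, 0.7], rng) for _ in range(5)]
actual_next = random.random()  # the global stream is unaffected by the sampling

print(tokens, expected_next == actual_next)
```

Because the draws come from the private generator, turning sampling on or off does not change the sequence of random numbers that training sees.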
3:27:25
Now, the other thing you'll notice is that we're running a bit slower. That's because I actually had to disable torch.compile to get this to sample. For some reason it works without torch.compile, but when I torch.compile my model I get a really scary error from PyTorch, and I have no idea how to resolve it right now. Probably by the time you see this code released, maybe it's fixed, but for now I'm just going to do "and False", bring back torch.compile, and you're not going to get samples. I think I'll fix this later.
3:27:59
By the way, I will be releasing all this code, and I've been very careful about making git commits every time we add something. So I'm going to release the entire repo, which starts completely from scratch and goes all the way to now and beyond, and everything should be exactly documented in the git commit history. I think that will be nice. Hopefully by the time you go to GitHub, this is removed, it's working, and I will have fixed the bug.
3:28:24
Okay, so I have the optimization running here, it's stepping, and we're on step 6,000 or so, so we're about 30% through training. While this is training, I would like to introduce one evaluation that we're going to use to supplement the validation set, and that is the HellaSwag eval.
3:28:38
HellaSwag comes from this paper back in 2019, so it's a five-year-old eval now. The way HellaSwag works is that it is basically a sentence completion dataset, and it's multiple choice. For every one of these questions we have a shared context, like "A woman is outside with a bucket and a dog. The dog is running around trying to avoid a bath. She", followed by four options: (A) rinses the bucket off with soap and blow dry the dog's head, (B) uses a hose to keep it from getting soapy, (C) gets the dog wet and it runs away again, or (D) gets into a bathtub with the dog.
3:29:16
The idea is that these multiple-choice options are constructed so that one of them is a natural continuation of the sentence and the others are not. The others might not make sense, like "uses a hose to keep it from getting soapy", which makes no sense. What happens is that models that are not trained very well are not able to tell these apart, but models that have a lot of world knowledge, and can tell a lot about the world, will be able to pick out the right completions. These sentences are sourced from ActivityNet and from WikiHow, and at the bottom of the paper there's a cool chart of the kinds of domains in WikiHow. There are a lot of sentences from Computers and Electronics, Home and Garden, and so on, so it has broad coverage of the kinds of things you need to know about the world in order to find the most likely completion.
3:30:22
One more thing that's interesting about HellaSwag is the way it was constructed: the incorrect options are deliberately adversarially sourced. They're not just random sentences; they're sentences generated by language models, in such a way that language models find them difficult but humans find them easy. They mention that humans have 95% accuracy on this set, but at the time the state-of-the-art language models had only 48%, so at the time this was a good benchmark. You can read the details of the paper to learn more.
3:30:59
The thing to point out, though, is that this was five years ago, and since then HellaSwag has been totally solved. The language models here are now at 96%, so the last 4% is probably errors in the dataset, or questions that are really, really hard. So this dataset is kind of crushed with respect to language models, whereas back then the best language model was only at about 50%. That is how far things have come.
3:31:30
Still, there are reasons people like HellaSwag. It's not used in GPT-2, by the way, but GPT-3 has a HellaSwag eval, and lots of people use HellaSwag. For GPT-3 we have results cited here, so we know what accuracy GPT-3 attains at all these different model checkpoints on the HellaSwag eval. The reason people like it is that HellaSwag is a smooth eval, and it is an eval that offers, quote unquote, "early signal". Early signal means that even small language models are going to start at the random chance of 25%, but they're going to slowly improve, and you're going to see 25, 26, 27, and so on. You can see slow improvement even when the models are very small and it's very early in training. So it's smooth, it has early signal, and it's been around for a long time. That's why people like this eval.
3:32:32
Now, the way that we're going to evaluate this is as follows. As I mentioned, we have a shared context, and this is kind of like a multiple-choice task. But we can't just give the model a multiple-choice question and ask it for A, B, C, or D, because models this small, as we are seeing here, can't actually do multiple choice. They don't understand the concept of associating a label with one of the options. So we have to give it to them in a native form, and the native form is a token completion.
3:33:05
So here's what we do. We construct a batch of four rows and T tokens, whatever that T happens to be. The shared context, which is the context for the four choices, has tokens that are shared across all of the rows. Then we have the four options, so we lay them out, and only one of the options is correct, in this case label 3, option 3. So this is the correct option, and options 1, 2, and 4 are incorrect.
3:33:33
Now, these options might be of different lengths, so we take the longest length, and that's the size of the batch, B by T. Some of these positions are going to be padded dimensions, so they're going to be unused. So we need the tokens, we need the correct label, and we need a mask that tells us which tokens are active; the mask is zero for the padded areas. That's how we construct these batches.
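The batch layout described above can be sketched with plain Python lists. Everything here is illustrative: `render_example` is a hypothetical helper, the token ids are made up rather than produced by a real tokenizer, and the sketch additionally assumes that the mask is zero over the shared context as well as over the padding, so that only each option's own tokens count as active.

```python
def render_example(ctx_tokens, endings_tokens, pad_id=0):
    """Build a 4 x T batch of token rows plus a matching 0/1 mask."""
    # T is the length of the longest (context + option) row.
    T = max(len(ctx_tokens) + len(end) for end in endings_tokens)
    tokens, mask = [], []
    for end in endings_tokens:
        row = ctx_tokens + end
        row_mask = [0] * len(ctx_tokens) + [1] * len(end)
        pad = T - len(row)
        tokens.append(row + [pad_id] * pad)  # padded positions are unused
        mask.append(row_mask + [0] * pad)    # and masked out
    return tokens, mask

ctx = [11, 12, 13]                                  # shared context (made-up ids)
endings = [[21], [31, 32], [41, 42, 43], [51, 52]]  # four candidate completions
label = 2                                           # zero-based index of the correct row (made up)
tokens, mask = render_example(ctx, endings)
for row, row_mask in zip(tokens, mask):
    print(row, row_mask)
```

Every row starts with the same context tokens, and the rows differ only in their option tokens and in how much padding they need to reach length T.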
3:34:04
construct these batches and then in
3:34:04
construct these batches and then in order to get the language model to
3:34:05
order to get the language model to
3:34:05
order to get the language model to predict A B C or D the way this works is
3:34:08
predict A B C or D the way this works is
3:34:08
predict A B C or D the way this works is basically we're just going to look at
3:34:10
basically we're just going to look at
3:34:10
basically we're just going to look at the tokens their probabilities and we're
3:34:12
the tokens their probabilities and we're
3:34:12
the tokens their probabilities and we're going to pick the option that gets the
3:34:15
going to pick the option that gets the
3:34:15
going to pick the option that gets the lowest or the highest average
3:34:18
lowest or the highest average
3:34:18
lowest or the highest average probability for the token so for the
3:34:22
probability for the token so for the
3:34:22
probability for the token so for the tokens because that is the most likely
3:34:25
tokens because that is the most likely
3:34:25
tokens because that is the most likely completion according to the language
3:34:26
completion according to the language
3:34:27
completion according to the language model so we're just going to look at the
3:34:29
model so we're just going to look at the
3:34:29
model so we're just going to look at the um probabilities here and average them
3:34:33
um probabilities here and average them
3:34:33
um probabilities here and average them up across the options and pick the one
3:34:35
up across the options and pick the one
3:34:35
up across the options and pick the one with the highest probability roughly
3:34:38
with the highest probability roughly
3:34:38
with the highest probability roughly speaking so this is how we're going to
3:34:40
speaking so this is how we're going to
3:34:40
speaking so this is how we're going to do H swag
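The batch construction described above can be sketched in a few lines of plain Python. This is only an illustration under assumptions: the helper name render_rows, the pad id of 0, and all token ids are made up, and the mask here marks only the candidate-ending tokens, which matches the later description of the mask in the walk-through of hellaswag.py.

```python
def render_rows(ctx_tokens, endings_tokens, pad_id=0):
    """Build padded token and mask rows for one multiple-choice example.

    Illustrative helper, not the video's actual code.
    """
    rows = [ctx_tokens + ending for ending in endings_tokens]
    # Mask is 0 over the shared context and 1 over each candidate ending.
    masks = [[0] * len(ctx_tokens) + [1] * len(ending) for ending in endings_tokens]
    max_len = max(len(row) for row in rows)  # the longest row sets T
    tokens = [row + [pad_id] * (max_len - len(row)) for row in rows]
    mask = [m + [0] * (max_len - len(m)) for m in masks]  # padding stays 0
    return tokens, mask


# Made-up token ids: one context and four candidate endings of different lengths.
ctx = [11, 12, 13]
endings = [[21], [31, 32, 33], [41, 42], [51, 52]]
label = 1  # index of the correct ending (hypothetical)

tokens, mask = render_rows(ctx, endings)
for token_row, mask_row in zip(tokens, mask):
    print(token_row, mask_row)
```

Every row has the same length, so the four rows can be stacked into one 4 by T batch, and the zeros in the mask keep both the context and the padding out of the score.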
3:34:42
And this is, I believe, also how GPT-3 did it, as far as I know. But you should note that some of the other evals where you might see HellaSwag may not do it this way. They may do it in a multiple-choice format, where you give the context a single time and then the four completions, so the model is able to see all four options before it picks the best one. That's actually an easier task for a model, because you get to see the other options when you're picking your choice.
3:35:15
Unfortunately, models at our size can't do that; only models at a bigger size are able to. So our models are actually slightly handicapped in this way: they are not going to see the other options. They only see one option at a time, they just have to assign probabilities, and the correct option has to win out in this metric.
3:35:34
All right, so let's now implement this very briefly and incorporate it into our script. What I've done here is I've introduced a new file called hellaswag.py that you can take a look at. I'm not going to step through all of it, because this is not exactly deep code; it's a little bit tedious, honestly. What's happening is that I'm downloading HellaSwag from GitHub and rendering all of its examples, and there are a total of 10,000 examples. I am rendering them into this format.
3:36:04
So here, at the end of this render_example function, you can see that I'm returning the tokens (this 4 by T array of tokens), the mask, which tells us which parts are the options while everything else is zero, and the label, which is the correct label. That allows us to then iterate the examples
3:36:28
and render them. I also have an evaluate function here, which can load a GPT-2 from Hugging Face and run the eval. Just as I described, it predicts the option that has the highest probability. The way to do that, actually, is to evaluate the cross-entropy loss: we're evaluating the loss of predicting the next token in a sequence, and then we look at the row that has the lowest average loss, and that's the option we pick as the prediction. Then we do some stats and prints and stuff like that. So that is a way to evaluate HellaSwag.
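The scoring rule just described (lowest average next-token loss over the option's tokens) can be sketched in pure Python. This is a toy under assumptions, not the code in hellaswag.py: the function names are invented, and the "model" is a stand-in that simply puts a high logit on the token id one above the current one. The detail worth noticing is the shift by one: the logits at position t score the token at position t+1, so the mask has to be shifted the same way.

```python
import math


def token_losses(logits, tokens):
    """Cross-entropy of each next-token prediction for one row.

    logits[t] scores the token at position t + 1, so targets are shifted by one.
    """
    losses = []
    for t in range(len(tokens) - 1):
        row = logits[t]
        m = max(row)
        log_z = m + math.log(sum(math.exp(x - m) for x in row))  # stable log-sum-exp
        losses.append(log_z - row[tokens[t + 1]])
    return losses


def pick_option(all_logits, all_tokens, all_masks):
    """Return (index of the row with the lowest masked average loss, all averages)."""
    avg = []
    for logits, tokens, mask in zip(all_logits, all_tokens, all_masks):
        losses = token_losses(logits, tokens)
        shifted_mask = mask[1:]  # align the mask with the shifted targets
        total = sum(l * m for l, m in zip(losses, shifted_mask))
        avg.append(total / sum(shifted_mask))
    return min(range(len(avg)), key=avg.__getitem__), avg


# Toy setup (all values made up): vocabulary of 5 tokens, and a fake "model"
# that prefers the next token to be (current + 1) mod 5.
VOCAB = 5
all_tokens = [[0, 1, 4, 4], [0, 1, 2, 3], [0, 1, 0, 0], [0, 1, 3, 1]]
all_masks = [[0, 0, 1, 1]] * 4  # context is [0, 1]; the last two tokens are the option
all_logits = [
    [[2.0 if v == (tok + 1) % VOCAB else 0.0 for v in range(VOCAB)] for tok in row]
    for row in all_tokens
]

pred, avg = pick_option(all_logits, all_tokens, all_masks)
print("predicted option:", pred)
print("average losses:", [round(a, 4) for a in avg])
```

Only the second row continues the pattern the toy model likes, so it gets the lowest average loss and is picked. Padded positions would carry a mask of zero and drop out of the average in the same way the context does.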
3:37:11
Now, if you go up here, I'm showing that for GPT-2 124M, if you run this script, you're going to see that HellaSwag gets 29.5%. So that's the performance we get here. Remember that random chance is 25%, so we haven't gone too far. GPT-2 XL, which is the biggest GPT-2, gets all the way up to 49%, roughly. These are pretty low values considering that today's state of the art is more like 95%, so these are definitely older models by now.
3:37:40
Then there's one more thing, the Eleuther harness, which is a very common piece of infrastructure for running evals for language models, and they get slightly different numbers. I'm not 100% sure what the discrepancy is. It could be that they actually do the multiple-choice format instead of just the completions, and that could be the discrepancy, but I'm not 100% sure about that; I'd have to take a look. For now, our script reports 29.55%, so that is the number we'd like to beat if we are training a GPT-2 124M from scratch ourselves.
3:38:16
So now I'm going to go into actually incorporating this eval into our main training script, basically because we want to evaluate it in a periodic manner, so that we can track HellaSwag, see how it evolves over time, and see when and if we cross this 29.55% region.
3:38:41
So let's now walk through some of the changes to train_gpt2.py. The first thing I did here is I actually made the use of compile optional, kind of, and I disabled it by default. The problem with compile is that, while it does make our code faster, it actually breaks the evaluation code and the sampling code. It gives me a very gnarly message and I don't know why. Hopefully, by the time you get to the codebase when I put it up on GitHub, we're going to have fixed that, but for now I'm running without torch.compile, which is why you see this be a bit slower. So we're running without torch.compile. I also create a log
3:39:15
directory, log, where we can place our log.txt, which will record the train loss, the validation loss, and the HellaSwag accuracies. It's a very simple text file: we open it for writing so that it starts empty, and then we append to it.
3:39:32
I created a simple variable that helps tell us when we are at the last step. Then, periodically inside this loop, every 250th iteration or at the last step, we evaluate the validation loss. Every 250th iteration we also evaluate HellaSwag, but only if we are not using compile, because compile breaks it. I'm going to come back to this code for evaluating HellaSwag in a second. And every 250th iteration, as well, we also sample from the model. You should recognize this as our ancient code from way back when we started the video; we're just sampling from the model.
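The schedule described above can be written down as a small function. This is a sketch under assumptions rather than the training script itself: the function name eval_flags and its return format are invented, and gating the sampling on compile is an assumption based on the earlier remark that compile breaks the sampling code as well.

```python
def eval_flags(step, max_steps, use_compile=False, every=250):
    """Decide which periodic jobs run before the training step at `step`."""
    last_step = (step == max_steps - 1)
    on_interval = (step % every == 0)
    return {
        # validation loss: every 250th iteration or at the very last step
        "val_loss": on_interval or last_step,
        # HellaSwag: every 250th iteration, skipped when compile is on
        "hellaswag": on_interval and not use_compile,
        # sampling: every 250th iteration (compile gate is an assumption)
        "sample": on_interval and not use_compile,
    }


max_steps = 1000  # small made-up value for the demo
for step in (0, 7, 250, 999):
    print(step, eval_flags(step, max_steps))
```

Keeping the conditions in one place makes it easy to see that with compile enabled the run would still log the validation loss but would produce no HellaSwag numbers or samples.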
3:40:15
And then finally here, after we validate, sample, and evaluate HellaSwag, we actually do a training step. This is one step of training, and you should be pretty familiar with all of what this does. At the end, once we get our training loss, we write it to the file. So the only thing that I really added is this entire section for the HellaSwag eval.
3:40:38
The way this works is that I'm trying to get all the GPUs to collaborate on HellaSwag. We're iterating over all the examples, and each process only picks the examples that are assigned to it: we take i modulo the world size and require it to be equal to the rank, otherwise we continue. Then we render an example, put it on the GPU, and get the logits. I created a helper function that predicts the option with the lowest loss, so that gives us the prediction here, and if it's correct we keep count.
3:41:15
Then, if multiple processes were collaborating on all this, we need to synchronize their stats. One way to do that is to package up our statistics into tensors, on which we can then call dist.all_reduce with a sum. Then we unwrap them from tensors so that we just have ints, and the master process prints and logs the HellaSwag accuracy.
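The sharding and the final sum can be simulated in a single process to see why the numbers add up. This is a sketch under assumptions: the names shard_stats and all_reduce_sum are invented, the examples and the predictor are toys, and the plain Python sum only stands in for the all-reduce with a sum operation that the real multi-GPU run would use.

```python
def shard_stats(examples, rank, world_size, predict):
    """Count correct predictions on the examples that belong to one rank."""
    num_correct, num_total = 0, 0
    for i, (example, label) in enumerate(examples):
        if i % world_size != rank:
            continue  # this example is handled by another process
        num_total += 1
        num_correct += int(predict(example) == label)
    return num_correct, num_total


def all_reduce_sum(per_rank_stats):
    """Single-process stand-in for an all-reduce with SUM across ranks."""
    return tuple(sum(values) for values in zip(*per_rank_stats))


# Toy data: ten examples whose label is i % 4, and a predictor that always says 0.
examples = [(i, i % 4) for i in range(10)]
predict = lambda example: 0
world_size = 4

per_rank = [shard_stats(examples, rank, world_size, predict) for rank in range(world_size)]
num_correct, num_total = all_reduce_sum(per_rank)
print(per_rank)
print(f"HellaSwag-style accuracy: {num_correct}/{num_total} = {num_correct / num_total:.4f}")
```

Because every index lands on exactly one rank, the summed totals cover each example once, and every process ends up with the same global counts after the reduction.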
3:41:41
So that's kind of it, and that's what I'm running right here. You see this optimization here, and we just had a generation. This is step 10,000 out of about 20,000, so we are halfway done, and these are the kinds of samples we are getting at this stage. Let's take a look: "Hello, I'm a language model, so I'd like to use it to generate some kinds of output." "Hello, I'm a language model, and I'm a developer for a lot of companies." Let's see if I can find a fun one.
3:42:28
I don't know, you can go through this yourself, but certainly the predictions are getting less and less random. It seems like the model is a little bit more self-aware and is using language that is a bit more specific to it being a language model: "Hello, I'm a language model, and like how the language is used to communicate." "I'm a language model, and I'm going to be speaking English and German." Okay, I don't know.
3:42:52
So let's just wait until this optimization finishes, and we'll see what kind of samples we get. We're also going to look at the train loss, the val loss, and the HellaSwag accuracy, and see how we're doing with respect to
3:43:06
GPT-2.
3:43:09
Okay, good morning. Focusing for a moment on the Jupyter notebook here on the right, I created a new cell that basically allows us to visualize the train loss, the val loss, and the HellaSwag score. You can step through this; it basically parses the log file that we are writing, and a lot of it is just boring matplotlib code. But basically, this is what our optimization looks like.
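Parsing such a log takes only a few lines. The sketch below makes assumptions that the video does not state at this point: each line is taken to be "step stream value" with the streams named train, val, and hella, and the numbers written here are made up purely for the demo. It also shows the open-for-writing-then-append pattern mentioned earlier.

```python
import os
import tempfile
from collections import defaultdict


def parse_log(path):
    """Read 'step stream value' lines into {stream: {step: value}}."""
    streams = defaultdict(dict)
    with open(path, "r") as f:
        for line in f:
            parts = line.split()
            if len(parts) != 3:
                continue  # skip anything that is not a 3-field record
            step, stream, value = parts
            streams[stream][int(step)] = float(value)
    return streams


log_dir = tempfile.mkdtemp()  # stands in for the log directory
log_file = os.path.join(log_dir, "log.txt")

with open(log_file, "w") as f:
    pass  # open for writing once so the file starts empty

with open(log_file, "a") as f:  # afterwards, only append
    f.write("0 val 10.9500\n")
    f.write("0 hella 0.2500\n")
    f.write("0 train 10.980000\n")
    f.write("1 train 9.900000\n")

streams = parse_log(log_file)
for name in sorted(streams):
    print(name, streams[name])
```

Each stream becomes a step-to-value mapping, so plotting is a matter of passing the sorted steps and their values to whatever plotting library the notebook uses.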
3:43:32
So we ran for 19,073 steps, which is about 10 billion tokens, which is one epoch of the sample-10B of FineWeb-Edu. On the left we have the loss: in blue we have the training loss, in orange we have the validation loss, and in red, as a horizontal line, we have the OpenAI GPT-2 124M model checkpoint when it's just evaluated on the validation set of this FineWeb-Edu.
3:44:04
So you can see that we are surpassing this: orange is below the red, so we're surpassing it on the validation set of this dataset. Like I mentioned, the dataset distribution is very different from what GPT-2 trained on, so this is not an exactly fair comparison, but it's a good cross-check to look at. Now, we would ideally like
3:44:24
look at now we would ideally like
3:44:25
look at now we would ideally like something that is withheld and
3:44:27
something that is withheld and
3:44:27
something that is withheld and comparable and somewhat standard um and
3:44:30
comparable and somewhat standard um and
3:44:30
comparable and somewhat standard um and so for us that is helis swag and so on
3:44:33
so for us that is helis swag and so on
3:44:33
so for us that is helis swag and so on here we see the H swag progress we made
3:44:35
here we see the H swag progress we made
3:44:35
here we see the H swag progress we made from 25% all the way here in red we see
3:44:39
from 25% all the way here in red we see
3:44:39
from 25% all the way here in red we see the open gpt2 124 M model in red so it
3:44:44
the open gpt2 124 M model in red so it
3:44:44
the open gpt2 124 M model in red so it achieves this h bag here and the the
3:44:47
achieves this h bag here and the the
3:44:47
achieves this h bag here and the the gpt3 model 124 M which was trained on
3:44:50
gpt3 model 124 M which was trained on
3:44:50
gpt3 model 124 M which was trained on 300 billion tokens achieves green so
3:44:54
300 billion tokens achieves green so
3:44:54
300 billion tokens achieves green so that's over here so you see that we
3:44:56
that's over here so you see that we
3:44:56
that's over here so you see that we basically surpassed the gbt2 24m uh
3:45:00
basically surpassed the gbt2 24m uh
3:45:00
basically surpassed the gbt2 24m uh model right here uh which is uh really
3:45:03
model right here uh which is uh really
3:45:03
model right here uh which is uh really nice now interestingly we were able to
3:45:07
nice now interestingly we were able to
3:45:07
Now, interestingly, we were able to do so with only training on 10 billion tokens, while GPT-2 was trained on 100 billion tokens. So for some reason we were able to get away with significantly fewer tokens for training. There are many possibilities as to why we could match or surpass this accuracy with only 10 billion tokens of training. Number one, it could be that OpenAI's GPT-2 was trained on a much wider data distribution. In particular, FineWeb-EDU is all English, it's not multilingual, and there's not that much math and code. So math, code, and multilingual data could have been stealing capacity from the original GPT-2 model, and that could be part of the reason for the gap.
3:45:54
There are many other reasons. For example, the HellaSwag eval is fairly old, maybe five years or so. It is possible that aspects of HellaSwag, in some form or even identically, have made it into the training set of FineWeb. We don't know for sure, but if that was the case then we are basically looking at the training curve instead of the validation curve. So, long story short, this is not a perfect eval and there are some caveats here, but at least we have some confidence that we're not doing something completely wrong.

3:46:23
It's probably the case that when people create these datasets, they try to make sure that very common test sets are not part of the training set. For example, when Hugging Face created FineWeb-EDU, they used HellaSwag as an eval, so I would hope that they made sure to deduplicate and that there's no HellaSwag in the training set. But we can't be sure.
3:46:45
The other thing I wanted to address briefly is this loss curve. This looks really wrong here. I don't actually know 100% what this is, and I suspect it's because the 10-billion-token sample of FineWeb-EDU was not properly shuffled. There's some issue here with the data that I don't fully understand yet, and there's some weird periodicity to it. Because we are, in a very lazy way, serializing all the tokens and just iterating over them from the start, without doing any permutation or random sampling ourselves, I think we're inheriting some of the ordering that they have in the dataset. So this is not ideal.

3:47:24
Hopefully, by the time you get to this repo, some of these things will be fixed. I will release this build-nanogpt repo. Right now it looks a little ugly and preliminary, so hopefully by the time you get here it's nicer. Down here I'm going to show errata, and I'm going to talk about some of the things that happened after the video, and I expect that we will have fixed this small issue.
3:47:52
For now, basically, this shows that our training is not completely wrong, and it shows that we're able to surpass the accuracy with a 10x smaller token budget. It could also be that the dataset has improved. The original GPT-2 dataset was WebText. It's possible that not a lot of care and attention went into that dataset, since this was very early in LLMs, whereas now there's a lot more scrutiny on good practices around deduplication, filtering, quality filtering, and so on. It's possible that the data we're training on is just of higher quality per token, and that could be giving us a boost as well. So there are a number of caveats to think about, but for now we're pretty happy with this.
3:48:36
Now, the next thing I was interested in: as you can see, it's morning now, so there was an overnight run, and I wanted to see how far I could push the result. To do an overnight run, instead of one epoch, which took roughly two hours, I just did times four, so that it would take eight hours while I was sleeping. So we did four epochs, or roughly 40 billion tokens of training, and I was trying to see how far we could get. This was the only change. I reran the script, and when I read the log file at 40B, this is what the curves look like.
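As a rough check on how long the four-epoch run is in optimizer steps, the arithmetic can be sketched as below. This is only an illustration: the 10-billion-token epoch and the batch of 2**19 tokens per step (the "half a million" tokens per step mentioned in the talk) are taken as assumptions, and the variable names are hypothetical.

```python
# Rough step-count arithmetic for the one-epoch and four-epoch runs.
# Assumptions: one epoch is ~10 billion tokens and each optimizer step
# consumes 2**19 tokens ("half a million").
tokens_per_epoch = 10_000_000_000
tokens_per_step = 2**19  # 524,288

steps_per_epoch = tokens_per_epoch // tokens_per_step
epochs = 4
max_steps = steps_per_epoch * epochs

print(steps_per_epoch)  # 19073
print(max_steps)        # 76292
```

Under these assumptions the four-epoch run lands at roughly 76 thousand steps, which is in line with the step count read off the log later in the talk.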
3:49:13
Okay, so to narrate this. Number one, we are seeing this issue here with the periodicity through the different epochs, and something really weird with the FineWeb-EDU dataset, and that is to be determined. But otherwise we are seeing that the HellaSwag accuracy actually went up by a lot, and we almost made it to the GPT-3 124M accuracy up here, but not quite. So it's too bad that I didn't sleep slightly longer. I think if this was a five-epoch run we may have gotten there.
3:49:47
Now, one thing to point out is that if you're doing multi-epoch runs, we're not actually being very careful in our data loader. This data loader goes through the data in exactly the same format and exactly the same order on every epoch. This is kind of suboptimal, and you would want to look into extensions where you actually permute the data randomly. You permute the documents around in every single shard on every single new epoch, and potentially even permute the shards. That would go a long way toward decreasing the periodicity, and it's also better for the optimization, so that you're not seeing things in the identical format and you're introducing some randomness in how the documents follow each other.

3:50:29
You have to remember that in every single row these documents follow each other: there's the end-of-text token and then the next document. So the documents are currently glued together in the exact same manner, but we actually want to break up the documents and shuffle them around, because the order of the documents shouldn't matter. Basically, we want to break up that dependence because it's a kind of spurious correlation. Our data loader is not currently doing that, and that's one improvement you could think of making.
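One way this kind of per-epoch shuffling could look is sketched below. It is a minimal illustration and not the data loader from the repo: the helper names (split_documents, shuffled_shard, shard_order) are hypothetical, tokens are held in plain Python lists, and it assumes each document in a shard begins with the GPT-2 end-of-text token (id 50256) as its delimiter.

```python
import random

EOT = 50256  # GPT-2 <|endoftext|> token id, assumed to start each document


def split_documents(tokens, eot=EOT):
    """Split a flat token list into documents, each starting at an EOT delimiter."""
    docs, current = [], []
    for tok in tokens:
        if tok == eot and current:
            docs.append(current)
            current = []
        current.append(tok)
    if current:
        docs.append(current)
    return docs


def shuffled_shard(tokens, epoch, seed=1337):
    """Return the shard's tokens with whole documents permuted for this epoch."""
    docs = split_documents(tokens)
    rng = random.Random(seed + epoch)  # reproducible, but different per epoch
    rng.shuffle(docs)
    return [tok for doc in docs for tok in doc]


def shard_order(num_shards, epoch, seed=1337):
    """Return a per-epoch permutation of the shard indices."""
    order = list(range(num_shards))
    random.Random(seed * 1000 + epoch).shuffle(order)
    return order


# Toy shard with three documents.
toy_shard = [EOT, 1, 2, EOT, 3, EOT, 4, 5, 6]
for epoch in range(2):
    print(epoch, shard_order(4, epoch), shuffled_shard(toy_shard, epoch))
```

Documents are moved as whole units, so the tokens inside each document keep their order and only the way documents follow each other changes between epochs. Seeding from the epoch number keeps the permutation reproducible, which would also let every process in a multi-GPU run agree on the same order.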
3:51:01
The other thing to point out is that we're almost matching GPT-3 accuracy with only 40 billion tokens. GPT-3 trained on 300 billion tokens, so again we're seeing about a 10x improvement here with respect to learning efficiency. I don't actually know exactly what to attribute this to, other than some of the things that I already mentioned for the previous run.
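The token-budget ratios behind these two comparisons can be worked out directly. The sketch below only restates the token counts quoted in the talk and takes them at face value; the variable names are illustrative.

```python
# Token budgets quoted in the talk, in billions of tokens.
ours_one_epoch, gpt2_tokens = 10, 100
ours_four_epochs, gpt3_tokens = 40, 300

print(gpt2_tokens / ours_one_epoch)    # 10.0
print(gpt3_tokens / ours_four_epochs)  # 7.5
```

So the "about 10x" figure is exact for the one-epoch run against GPT-2 and closer to 7.5x for the four-epoch run against GPT-3.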
3:51:23
The other thing I wanted to briefly mention is the max LR here. I saw some people already play with this a little bit in a previous related repository, and it turns out that you can actually almost 3x this. So it's possible that the maximum learning rate can be a lot higher, and for some reason the GPT-3 hyperparameters that we are inheriting are actually extremely conservative: you can get away with a higher learning rate, and it would train faster. A lot of these hyperparameters are quite tunable, so feel free to play with them. They're probably not set precisely correctly, and it's possible that you can get away with doing this, basically.
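To make the effect of a larger maximum learning rate concrete, here is a minimal sketch of a warmup-plus-cosine-decay schedule evaluated at a baseline and at 3x that baseline. The specific numbers are assumptions for illustration rather than values stated in this part of the talk: a baseline max LR of 6e-4, 715 warmup steps, 19,073 total steps, and a floor of 10% of the max LR. The function name get_lr is also illustrative.

```python
import math


def get_lr(step, max_lr, warmup_steps=715, max_steps=19073, min_ratio=0.1):
    """Linear warmup, then cosine decay from max_lr down to min_ratio * max_lr."""
    min_lr = max_lr * min_ratio
    if step < warmup_steps:
        return max_lr * (step + 1) / warmup_steps
    if step > max_steps:
        return min_lr
    decay_ratio = (step - warmup_steps) / (max_steps - warmup_steps)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # goes from 1 to 0
    return min_lr + coeff * (max_lr - min_lr)


base_max_lr = 6e-4                # assumed baseline
tripled_max_lr = 3 * base_max_lr  # the "3x" setting discussed above

for step in (0, 715, 10000, 19073):
    print(step, get_lr(step, base_max_lr), get_lr(step, tripled_max_lr))
```

Because this schedule is linear in max_lr, tripling max_lr triples the learning rate at every step. Whether training stays stable at the higher rate is an empirical question that the sketch does not answer.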
3:52:03
If you wanted to be exactly faithful to GPT-3, you would also want to make the following change. The sequence length of GPT-3 is 2x: it's 2048 instead of 1024. So you would come here and change T to 2048, and then, if you want the exact same number of tokens, half a million per iteration or per step, you want to decrease the batch size to 32, so they still multiply to half a million. That would give your model a sequence length equal to that of GPT-3, and in that case the models would be roughly identical, as far as I'm aware, because again GPT-2 and GPT-3 are very, very similar models.
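The bookkeeping for this change can be sketched as follows. The sketch assumes the original setting is a micro-batch of B=64 sequences of T=1024 tokens per GPU, 8 GPUs, and a total batch of 2**19 tokens per step; those values and the variable names are assumptions for illustration.

```python
# Tokens per optimizer step for the two sequence-length settings.
# Assumptions: 8 data-parallel GPUs and a total batch of 2**19 tokens per step.
total_batch_size = 2**19  # 524,288 tokens, i.e. "half a million"
num_gpus = 8              # assumed world size

configs = {
    "T=1024": (64, 1024),  # (B, T), the assumed original setting
    "T=2048": (32, 2048),  # GPT-3 style sequence length with B halved
}

for name, (B, T) in configs.items():
    tokens_per_micro_step = B * T * num_gpus
    assert total_batch_size % tokens_per_micro_step == 0
    grad_accum_steps = total_batch_size // tokens_per_micro_step
    print(name, B * T, tokens_per_micro_step, grad_accum_steps)
```

Halving B while doubling T keeps B * T fixed, so the number of tokens per step is unchanged. The model's maximum context length would presumably also have to allow 2048 positions for the longer sequences to fit.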
3:52:47
Now we can also look at some of the samples from the model that was trained overnight. This is the optimization, and you see that here we stepped all the way to 76,290 or so. The HellaSwag accuracy we achieved was 33.24, and these are some of the samples from the model. If you pause the video briefly and read through them, you can see that they are a lot more coherent, and they're actually addressing the fact that it's a language model, almost: "Hello, I'm a language model, and I try to be as accurate as possible." "I'm a language model, not a programming language. I know how to communicate. I use Python." If you pause this, look at it, and then compare it to the model that was only trained for 10 billion tokens, you will see that these are a lot more coherent, and you can play with this yourself.
3:53:50
yourself one more thing I added to The Code by the way is this chunk of code
3:53:52
Code by the way is this chunk of code
3:53:52
Code by the way is this chunk of code here so basically right after we
3:53:54
here so basically right after we
3:53:54
here so basically right after we evaluate the validation loss if we are
3:53:56
evaluate the validation loss if we are
3:53:56
evaluate the validation loss if we are the master process in addition to
3:53:58
the master process in addition to
3:53:58
the master process in addition to logging the validation loss every 5,000
3:54:01
logging the validation loss every 5,000
3:54:01
logging the validation loss every 5,000 steps we're also going to save the
3:54:02
steps we're also going to save the
3:54:02
steps we're also going to save the checkpoint which is really just the
3:54:04
checkpoint which is really just the
3:54:04
checkpoint which is really just the state dictionary of the model and so
3:54:07
state dictionary of the model and so
3:54:07
state dictionary of the model and so checkpointing is nice just because uh
3:54:09
checkpointing is nice just because uh
3:54:09
checkpointing is nice just because uh you can save the model and later you can
3:54:11
you can save the model and later you can
3:54:11
you can save the model and later you can uh use it in some way if you wanted to
3:54:13
uh use it in some way if you wanted to
3:54:13
uh use it in some way if you wanted to resume the optimiz ation then in
3:54:15
resume the optimiz ation then in
3:54:15
resume the optimiz ation then in addition to saving the model we have to
3:54:17
addition to saving the model we have to
3:54:17
addition to saving the model we have to also save the optimizer State dict
3:54:20
also save the optimizer State dict
3:54:20
also save the optimizer State dict because remember that the optimizer has
3:54:21
because remember that the optimizer has
3:54:21
because remember that the optimizer has a few additional buffers because of adom
3:54:24
a few additional buffers because of adom
3:54:24
a few additional buffers because of adom so it's got the m and V and uh you need
3:54:28
so it's got the m and V and uh you need
3:54:28
so it's got the m and V and uh you need to also resume the optimizer properly
3:54:30
to also resume the optimizer properly
3:54:30
to also resume the optimizer properly you have to be careful with your RNG
3:54:31
you have to be careful with your RNG
3:54:31
you have to be careful with your RNG seeds uh random number generators and so
3:54:33
seeds uh random number generators and so
3:54:33
seeds uh random number generators and so on so if you wanted to exactly be able
3:54:35
on so if you wanted to exactly be able
3:54:35
on so if you wanted to exactly be able to resume optimization you have to think
3:54:37
to resume optimization you have to think
3:54:37
to resume optimization you have to think through the state of the of the training
3:54:40
through the state of the of the training
3:54:40
through the state of the of the training process but if you just want to save the
3:54:41
process but if you just want to save the
3:54:41
process but if you just want to save the model this is how you would do it and
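To see why the Adam buffers matter for resuming, here is a minimal sketch using a hand-written scalar Adam in plain Python rather than PyTorch's optimizer. Everything in it (the toy loss, the function names, the hyperparameters) is an assumption chosen for illustration. It compares an uninterrupted run against a run resumed with the full state and a run resumed with only the weights.

```python
import copy
import math


def new_state(w=0.0):
    """Parameter plus Adam's extra buffers: first moment m, second moment v, step count t."""
    return {"w": w, "m": 0.0, "v": 0.0, "t": 0}


def adam_step(state, grad, lr=0.1, beta1=0.9, beta2=0.999, eps=1e-8):
    """Apply one Adam update to the scalar parameter stored in `state`."""
    state["t"] += 1
    state["m"] = beta1 * state["m"] + (1 - beta1) * grad
    state["v"] = beta2 * state["v"] + (1 - beta2) * grad * grad
    m_hat = state["m"] / (1 - beta1 ** state["t"])  # bias-corrected first moment
    v_hat = state["v"] / (1 - beta2 ** state["t"])  # bias-corrected second moment
    state["w"] -= lr * m_hat / (math.sqrt(v_hat) + eps)


def train(state, steps):
    """Minimize the toy loss (w - 3)^2 for `steps` Adam updates."""
    for _ in range(steps):
        grad = 2.0 * (state["w"] - 3.0)
        adam_step(state, grad)
    return state["w"]


# Reference: 20 uninterrupted steps.
reference = train(new_state(), 20)

# Interrupted run: stop after 10 steps and take two kinds of checkpoint.
state = new_state()
train(state, 10)
full_checkpoint = copy.deepcopy(state)   # weights plus m, v, t
weights_only = new_state(w=state["w"])   # weights only, buffers reset

resumed_full = train(full_checkpoint, 10)
resumed_weights_only = train(weights_only, 10)

print(resumed_full == reference)          # True: same trajectory
print(resumed_weights_only == reference)  # False: the buffers were lost
```

In PyTorch the same buffers live in `optimizer.state_dict()`, which is why it would be saved alongside `model.state_dict()` when exact resumption is the goal.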
3:54:43
model this is how you would do it and
3:54:43
model this is how you would do it and one one nice reason why you might want
3:54:45
one one nice reason why you might want
3:54:45
one one nice reason why you might want to do this is because you may want to
3:54:47
to do this is because you may want to
3:54:47
to do this is because you may want to evaluate the model a lot more carefully
3:54:50
evaluate the model a lot more carefully
3:54:50
evaluate the model a lot more carefully so here we are only kind of like winging
3:54:52
so here we are only kind of like winging
3:54:52
so here we are only kind of like winging the hell swag eval but you may want to
3:54:54
the hell swag eval but you may want to
3:54:54
the hell swag eval but you may want to use something um nicer like for example
3:54:57
use something um nicer like for example
3:54:57
use something um nicer like for example the Luther uh Luther evaluation hardness
3:55:01
the Luther uh Luther evaluation hardness
3:55:01
the Luther uh Luther evaluation hardness evaluation hardness hardness um so this
3:55:06
evaluation hardness hardness um so this
3:55:06
evaluation hardness hardness um so this is a way to also evaluate language
3:55:08
is a way to also evaluate language
3:55:08
is a way to also evaluate language models and um so it's possible that um
3:55:13
models and um so it's possible that um
3:55:13
models and um so it's possible that um you may want to use basically different
3:55:15
you may want to use basically different
3:55:15
you may want to use basically different infrastructure to more thoroughly
3:55:17
infrastructure to more thoroughly
3:55:17
infrastructure to more thoroughly evaluate the models on different um
3:55:20
evaluate the models on different um
3:55:20
evaluate the models on different um evaluations and compare it to the
3:55:21
evaluations and compare it to the
3:55:21
evaluations and compare it to the opening gbt2 model on many other um
3:55:25
opening gbt2 model on many other um
3:55:25
opening gbt2 model on many other um tasks like for example that involve math
3:55:26
tasks like for example that involve math
3:55:26
tasks like for example that involve math code or different languages and so on so
3:55:29
code or different languages and so on so
3:55:29
code or different languages and so on so this is a nice functionality to have as
3:55:30
this is a nice functionality to have as
3:55:30
this is a nice functionality to have as well
3:55:32
The other thing I wanted to mention is that everything we've built here is only the pre-training step. The GPT here dreams documents; it just predicts the next token. You can't talk to it like you can talk to ChatGPT. If you wanted to talk to the model, we would have to fine-tune it into the chat format, and it's not actually that complicated. If you're looking at supervised fine-tuning, or SFT, really what that means is we're just swapping out the dataset for one that is a lot more conversational, with a user, assistant, user, assistant kind of structure, and we just fine-tune on it. Then we basically fill in the user tokens and sample the assistant tokens. It's not a lot deeper than that: we swap out the dataset and continue training. But for now we're going to stop at pre-training.
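As a rough picture of what "swapping out the dataset" could look like, the sketch below flattens user and assistant turns into a single training document and builds an inference prompt that ends where the assistant should start talking. The marker strings and function names are hypothetical; real chat models each define their own special tokens and templates.

```python
USER, ASSISTANT, END = "<|user|>", "<|assistant|>", "<|end|>"  # hypothetical markers


def render_conversation(turns):
    """Flatten (role, text) turns into one training document."""
    parts = []
    for role, text in turns:
        marker = USER if role == "user" else ASSISTANT
        parts.append(f"{marker}{text}{END}")
    return "".join(parts)


def render_prompt(turns):
    """At inference: fill in the turns so far, then open an assistant turn to sample."""
    return render_conversation(turns) + ASSISTANT


conversation = [
    ("user", "What is 2 + 2?"),
    ("assistant", "4"),
    ("user", "And times 3?"),
]
print(render_conversation(conversation[:2]))
# <|user|>What is 2 + 2?<|end|><|assistant|>4<|end|>
print(render_prompt(conversation))
```

Documents rendered this way would be tokenized and trained on with the same next-token objective as in pre-training; at inference the model's continuation after the final assistant marker is the reply.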
3:56:21
stop at uh pre-training one more thing
3:56:21
stop at uh pre-training one more thing that I wanted to briefly show you is
3:56:23
that I wanted to briefly show you is
3:56:23
that I wanted to briefly show you is that of course what we've built up today
3:56:25
that of course what we've built up today
3:56:25
that of course what we've built up today was building towards nanog GPT which is
3:56:27
was building towards nanog GPT which is
3:56:27
was building towards nanog GPT which is this repository from earlier uh but also
3:56:30
this repository from earlier uh but also
3:56:30
this repository from earlier uh but also there's actually another nanog GPT
3:56:32
there's actually another nanog GPT
3:56:32
there's actually another nanog GPT implementation and it's hiding in a more
3:56:34
implementation and it's hiding in a more
3:56:34
implementation and it's hiding in a more recent project that I've been working on
3:56:36
recent project that I've been working on
3:56:36
recent project that I've been working on called llm Doc and lm. C is a pure Cuda
3:56:41
called llm Doc and lm. C is a pure Cuda
3:56:41
called llm Doc and lm. C is a pure Cuda implementation of gpt2 or gpt3 training
3:56:44
implementation of gpt2 or gpt3 training
3:56:44
implementation of gpt2 or gpt3 training and it just directly uses uh Cuda and is
3:56:47
and it just directly uses uh Cuda and is
3:56:47
and it just directly uses uh Cuda and is written as Cuda now the nanog gbt here
3:56:51
written as Cuda now the nanog gbt here
3:56:51
written as Cuda now the nanog gbt here acts as reference code in pytorch to the
3:56:53
acts as reference code in pytorch to the
3:56:53
acts as reference code in pytorch to the C implementation so we're trying to
3:56:55
C implementation so we're trying to
3:56:55
C implementation so we're trying to exactly match up the two but we're
3:56:57
exactly match up the two but we're
3:56:57
exactly match up the two but we're hoping that the C Cuda is faster and of
3:56:59
hoping that the C Cuda is faster and of
3:56:59
hoping that the C Cuda is faster and of course currently that seems to be the
3:57:01
course currently that seems to be the
3:57:01
course currently that seems to be the case um because it is a direct optimized
3:57:04
case um because it is a direct optimized
3:57:04
case um because it is a direct optimized implementation so train gpt2 Pi in LL
3:57:06
implementation so train gpt2 Pi in LL
3:57:06
implementation so train gpt2 Pi in LL M.C is basically the nanog GPT and when
3:57:10
M.C is basically the nanog GPT and when
3:57:10
M.C is basically the nanog GPT and when you scroll through this file you'll find
3:57:12
you scroll through this file you'll find
3:57:12
you scroll through this file you'll find a lot of things that very much look like
3:57:16
a lot of things that very much look like
3:57:16
a lot of things that very much look like um things that we've built up in this
3:57:19
um things that we've built up in this
3:57:19
um things that we've built up in this lecture and then when you look at train
3:57:21
lecture and then when you look at train
3:57:21
lecture and then when you look at train gpt2 docu uh this is the C Cuda
3:57:25
gpt2 docu uh this is the C Cuda
3:57:25
gpt2 docu uh this is the C Cuda implementation so there's a lot of MPI
3:57:27
implementation so there's a lot of MPI
3:57:27
implementation so there's a lot of MPI nickel GPU Cuda
3:57:30
nickel GPU Cuda
3:57:30
nickel GPU Cuda cc++ and you have to be familiar with
3:57:32
cc++ and you have to be familiar with
3:57:32
cc++ and you have to be familiar with that but uh um when this is built up we
3:57:37
that but uh um when this is built up we
3:57:37
that but uh um when this is built up we can actually run the two side by side
3:57:39
can actually run the two side by side
3:57:39
can actually run the two side by side and they're going to produce the exact
3:57:40
and they're going to produce the exact
3:57:40
and they're going to produce the exact same results but lm. C actually runs
3:57:43
same results but lm. C actually runs
3:57:43
same results but lm. C actually runs faster so let's see that so on the left
3:57:45
faster so let's see that so on the left
3:57:45
faster so let's see that so on the left I have pytorch a nanog GPT looking thing
3:57:49
I have pytorch a nanog GPT looking thing
3:57:49
I have pytorch a nanog GPT looking thing on the right I have the llmc call and
3:57:52
on the right I have the llmc call and
3:57:52
on the right I have the llmc call and here I'm going to launch the
3:57:54
here I'm going to launch the
3:57:54
here I'm going to launch the two both of these are going to be
3:57:55
two both of these are going to be
3:57:55
two both of these are going to be running on a single GPU and here I'm
3:57:57
running on a single GPU and here I'm
3:57:57
running on a single GPU and here I'm putting the lm. C on GPU 1 and this one
3:58:00
putting the lm. C on GPU 1 and this one
3:58:00
putting the lm. C on GPU 1 and this one will grab uh gpu0 by default and
3:58:05
will grab uh gpu0 by default and
3:58:05
will grab uh gpu0 by default and then we can see here that lm. c
3:58:08
then we can see here that lm. c
3:58:08
then we can see here that lm. c compiled and then allocate space and
3:58:11
compiled and then allocate space and
3:58:11
compiled and then allocate space and it's
3:58:12
it's
3:58:12
it's stepping so
3:58:15
stepping so
3:58:15
stepping so basically uh meanwhile P torch is still
3:58:17
basically uh meanwhile P torch is still
3:58:17
basically uh meanwhile P torch is still compiling because torch compile is a bit
3:58:19
compiling because torch compile is a bit
3:58:19
compiling because torch compile is a bit slower here than the lm. C nbcc Cuda
3:58:24
slower here than the lm. C nbcc Cuda
3:58:24
slower here than the lm. C nbcc Cuda compile and so this program has already
3:58:26
compile and so this program has already
3:58:26
compile and so this program has already started running and uh we're still
3:58:28
started running and uh we're still
3:58:28
started running and uh we're still waiting here for torch compile now of
3:58:30
waiting here for torch compile now of
3:58:30
waiting here for torch compile now of course uh this is a very specific
3:58:33
course uh this is a very specific
3:58:33
course uh this is a very specific implementation to gpt2 and 3 a pytorch
3:58:35
implementation to gpt2 and 3 a pytorch
3:58:35
implementation to gpt2 and 3 a pytorch is a very general neural network
3:58:37
is a very general neural network
3:58:37
is a very general neural network framework so they're not exactly
3:58:38
framework so they're not exactly
3:58:38
framework so they're not exactly comparable but if you're only interested
3:58:39
comparable but if you're only interested
3:58:39
comparable but if you're only interested in training gpt2 and 3 lm. C is very
3:58:43
in training gpt2 and 3 lm. C is very
3:58:43
in training gpt2 and 3 lm. C is very fast it takes less space it's faster to
3:58:46
fast it takes less space it's faster to
3:58:46
fast it takes less space it's faster to start and it's faster per
3:58:49
start and it's faster per
3:58:49
start and it's faster per step and so P started to Stepping here
3:58:53
step and so P started to Stepping here
3:58:53
step and so P started to Stepping here and as you can see we're running at
3:58:54
and as you can see we're running at
3:58:54
and as you can see we're running at about 223,000 tokens per second here and
3:58:57
about 223,000 tokens per second here and
3:58:57
about 223,000 tokens per second here and about 185,000 tokens per second here um
3:59:02
about 185,000 tokens per second here um
3:59:03
about 185,000 tokens per second here um so quite a bit slower but I don't have
3:59:05
so quite a bit slower but I don't have
3:59:05
so quite a bit slower but I don't have full confidence that I exactly squeezed
3:59:08
full confidence that I exactly squeezed
3:59:08
full confidence that I exactly squeezed out all the juice from the pytorch
3:59:09
out all the juice from the pytorch
3:59:09
out all the juice from the pytorch implementation but the important thing
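To put "quite a bit slower" into numbers, here is a small calculation with the two throughput figures quoted above. It assumes, as the surrounding remarks suggest, that the higher figure belongs to llm.c and the lower one to PyTorch; the variable names are just for illustration.

```python
llmc_tok_per_s = 223_000     # figure quoted above, assumed to be llm.c
pytorch_tok_per_s = 185_000  # figure quoted above, assumed to be PyTorch

speedup = llmc_tok_per_s / pytorch_tok_per_s
pytorch_deficit = 1 - pytorch_tok_per_s / llmc_tok_per_s

print(f"llm.c processes {speedup:.2f}x the tokens per second")  # 1.21x
print(f"PyTorch throughput is {pytorch_deficit:.0%} lower")     # 17%
```

These are single-GPU numbers from one run, so they are a rough indication rather than a benchmark.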
3:59:11
implementation but the important thing
3:59:11
implementation but the important thing here is notice that if I Aline up the
3:59:14
here is notice that if I Aline up the
3:59:14
here is notice that if I Aline up the steps you will see that the losses and
3:59:16
steps you will see that the losses and
3:59:16
steps you will see that the losses and Norms that are printed between these two
3:59:18
Norms that are printed between these two
3:59:18
Norms that are printed between these two are
3:59:19
are
3:59:19
are identical so on the left we have the pie
3:59:21
identical so on the left we have the pie
3:59:21
identical so on the left we have the pie torch and on the right this C
3:59:23
torch and on the right this C
3:59:24
torch and on the right this C implementation and they're the same
3:59:25
implementation and they're the same
3:59:25
implementation and they're the same except this one runs faster uh so that's
3:59:28
except this one runs faster uh so that's
3:59:28
except this one runs faster uh so that's kind of I wanted to show you also
3:59:30
kind of I wanted to show you also
3:59:30
kind of I wanted to show you also briefly lm. C and this is a parallel
3:59:33
briefly lm. C and this is a parallel
3:59:33
briefly lm. C and this is a parallel implementation and it's also something
3:59:35
implementation and it's also something
3:59:35
implementation and it's also something that you may want to uh play with or
3:59:36
that you may want to uh play with or
3:59:37
that you may want to uh play with or look at and um it's kind of interesting
3:59:39
look at and um it's kind of interesting
3:59:39
look at and um it's kind of interesting okay so at this point I should probably
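Checking two implementations against each other by eye works for a few steps; a small script makes the same check repeatable. The sketch below assumes each run's log has already been parsed into `(step, loss, grad_norm)` tuples, which is a hypothetical format, and the numbers are made up purely to exercise the function.

```python
import math


def first_mismatch(run_a, run_b, rel_tol=1e-4):
    """Return the first step where loss or grad norm disagree, or None.

    Each run is a list of (step, loss, grad_norm) tuples. The tolerance is an
    assumption: printed logs are rounded, so exact float equality is too strict.
    """
    for (step_a, loss_a, norm_a), (step_b, loss_b, norm_b) in zip(run_a, run_b):
        if step_a != step_b:
            raise ValueError(f"steps are not aligned: {step_a} vs {step_b}")
        if not (math.isclose(loss_a, loss_b, rel_tol=rel_tol)
                and math.isclose(norm_a, norm_b, rel_tol=rel_tol)):
            return step_a
    return None


# Made-up numbers purely to exercise the function; not values from the video.
pytorch_log = [(0, 11.0, 15.0), (1, 10.0, 12.0), (2, 9.5, 8.0)]
llmc_log = [(0, 11.0, 15.0), (1, 10.0, 12.0), (2, 9.5, 8.0)]
drifted_log = [(0, 11.0, 15.0), (1, 10.2, 12.0), (2, 9.5, 8.0)]

print(first_mismatch(pytorch_log, llmc_log))     # None
print(first_mismatch(pytorch_log, drifted_log))  # 1
```

Reporting the first diverging step is useful because a mismatch early in training usually points at initialization or data order, while a later one points at the update rule.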
3:59:40
okay so at this point I should probably
3:59:40
okay so at this point I should probably start wrapping up the video because I
3:59:42
start wrapping up the video because I
3:59:42
start wrapping up the video because I think it's getting way longer than I
3:59:44
think it's getting way longer than I
3:59:44
think it's getting way longer than I anticipated uh but we did Cover a lot of
3:59:46
anticipated uh but we did Cover a lot of
3:59:46
anticipated uh but we did Cover a lot of ground and we built everything from
3:59:48
ground and we built everything from
3:59:48
ground and we built everything from scratch so as a brief summary we were
3:59:50
scratch so as a brief summary we were
3:59:50
scratch so as a brief summary we were looking at the gpt2 and GPT 3
3:59:54
looking at the gpt2 and GPT 3
3:59:55
looking at the gpt2 and GPT 3 papers we were looking at how you set up
3:59:57
papers we were looking at how you set up
3:59:57
papers we were looking at how you set up these training runs uh and all the
3:59:59
these training runs uh and all the
3:59:59
these training runs uh and all the considerations involved we wrote
4:00:01
considerations involved we wrote
4:00:01
considerations involved we wrote everything from scratch and then we saw
4:00:03
everything from scratch and then we saw
4:00:03
everything from scratch and then we saw that over the duration of either a
4:00:04
that over the duration of either a
4:00:04
that over the duration of either a 2-hour training run or an overnight run
4:00:07
2-hour training run or an overnight run
4:00:07
2-hour training run or an overnight run we can actually match the 124 million
4:00:09
we can actually match the 124 million
4:00:09
we can actually match the 124 million parameter checkpoints of gbt2 and gpt3
4:00:12
parameter checkpoints of gbt2 and gpt3
4:00:12
parameter checkpoints of gbt2 and gpt3 uh to a very large extent
4:00:14
uh to a very large extent
4:00:14
uh to a very large extent um in principle the code that we wrote
4:00:16
um in principle the code that we wrote
4:00:16
um in principle the code that we wrote would be able to train even bigger
4:00:18
would be able to train even bigger
4:00:18
would be able to train even bigger models if you have the patients or the
4:00:19
models if you have the patients or the
4:00:19
models if you have the patients or the Computing resources uh and so you could
4:00:21
Computing resources uh and so you could
4:00:21
Computing resources uh and so you could potentially think about training some of
4:00:23
potentially think about training some of
4:00:23
potentially think about training some of the bigger checkpoints as well um there
4:00:26
the bigger checkpoints as well um there
4:00:26
the bigger checkpoints as well um there are a few remaining issues to address
4:00:28
are a few remaining issues to address
4:00:28
are a few remaining issues to address what's happening with the loss here
4:00:30
what's happening with the loss here
4:00:30
what's happening with the loss here which I suspect has to do with the fine
4:00:31
which I suspect has to do with the fine
4:00:31
which I suspect has to do with the fine web edu data sampling uh why can't we
4:00:34
web edu data sampling uh why can't we
4:00:34
web edu data sampling uh why can't we turn on Torch compile uh it currently
4:00:36
turn on Torch compile uh it currently
4:00:36
turn on Torch compile uh it currently breaks generation and H swag what's up
4:00:39
breaks generation and H swag what's up
4:00:39
breaks generation and H swag what's up with that in the data loader we should
4:00:41
with that in the data loader we should
4:00:41
with that in the data loader we should probably be permuting our data when we
4:00:43
probably be permuting our data when we
4:00:43
probably be permuting our data when we reach boundaries so there's a few more
4:00:45
reach boundaries so there's a few more
4:00:45
reach boundaries so there's a few more issues like that and I expect to be
4:00:47
issues like that and I expect to be
4:00:47
issues like that and I expect to be documenting some of those over time in
4:00:49
documenting some of those over time in
4:00:49
documenting some of those over time in the uh build n GPT repository here which
4:00:53
the uh build n GPT repository here which
4:00:53
the uh build n GPT repository here which I'm going to be releasing with this
4:00:55
I'm going to be releasing with this
4:00:55
I'm going to be releasing with this video if you have any questions or like
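One possible reading of "permuting our data when we reach boundaries" is to reshuffle the order of the data shards each time the loader wraps around to a new epoch. The sketch below shows that idea in plain Python; the function names, the seed, and the shard names are placeholders, and this is not the fix used in the repository.

```python
import random


def shard_order(shards, epoch, seed=1337):
    """Return the shards in a fresh pseudo-random order for the given epoch.

    Seeding with (seed + epoch) makes each epoch's order reproducible, which
    also helps when resuming a run. The seed value is arbitrary.
    """
    order = list(shards)
    random.Random(seed + epoch).shuffle(order)
    return order


def iterate_shards(shards, num_epochs):
    """Yield (epoch, shard) pairs, re-permuting the order at each epoch boundary."""
    for epoch in range(num_epochs):
        for shard in shard_order(shards, epoch):
            yield epoch, shard


shards = [f"shard_{i:03d}" for i in range(5)]  # placeholder shard names
for epoch, shard in iterate_shards(shards, num_epochs=2):
    print(epoch, shard)
```

Every shard is still visited exactly once per epoch; only the order changes, so the model does not see the same sequence of batches on every pass.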
4:00:57
video if you have any questions or like
4:00:57
video if you have any questions or like to talk about anything that we covered
4:00:59
to talk about anything that we covered
4:00:59
to talk about anything that we covered please go to discussions tab uh so we
4:01:02
please go to discussions tab uh so we
4:01:02
please go to discussions tab uh so we can talk here uh or please go to issues
4:01:04
can talk here uh or please go to issues
4:01:04
can talk here uh or please go to issues or pull request pull requests um
4:01:07
or pull request pull requests um
4:01:07
or pull request pull requests um depending on what you'd like to
4:01:08
depending on what you'd like to
4:01:08
depending on what you'd like to contribute or also have a look at the uh
4:01:11
contribute or also have a look at the uh
4:01:11
contribute or also have a look at the uh Zero to Hero Discord and uh I'm going to
4:01:14
Zero to Hero Discord and uh I'm going to
4:01:14
Zero to Hero Discord and uh I'm going to be hanging out here on N GPT
4:01:17
be hanging out here on N GPT
4:01:17
be hanging out here on N GPT um otherwise for now I'm pretty happy
4:01:20
um otherwise for now I'm pretty happy
4:01:20
um otherwise for now I'm pretty happy about where we got um and I hope you
4:01:23
about where we got um and I hope you
4:01:23
about where we got um and I hope you enjoyed the video and I will see you
4:01:24
enjoyed the video and I will see you
4:01:25
enjoyed the video and I will see you later