Neural Networks: Zero to Hero
Building makemore Part 4: Becoming a Backprop Ninja
0:00
Hi everyone. Today we are once again continuing our implementation of makemore. So far we've come up to here: multilayer perceptrons. Our neural net looked like this, and we were implementing it over the last few lectures.
0:13
Now, I'm sure everyone is very excited to go into recurrent neural networks and all of their variants and how they work; the diagrams look cool, it's very exciting and interesting, and we're going to get a better result. But unfortunately, I think we have to remain here for one more lecture. The reason is that we've already trained this multilayer perceptron, we are getting a pretty good loss, and I think we have a pretty decent understanding of the architecture and how it works. But the line of code here that I take issue with is loss.backward(). That is, we are taking PyTorch autograd and using it to calculate all of our gradients along the way. I would like to remove the use of loss.backward() and have us write our backward pass manually, on the level of tensors. I think this is a very useful exercise, for the following reasons.
0:58
I actually have an entire blog post on this topic, but I'd like to call backpropagation a leaky abstraction. What I mean by that is that backpropagation doesn't just make your neural networks work magically. It's not the case that you can just stack up arbitrary Lego blocks of differentiable functions, cross your fingers, backpropagate, and everything is great. Things don't just work automatically. It is a leaky abstraction in the sense that you can shoot yourself in the foot if you do not understand its internals: it will magically not work, or not work optimally, and you will need to understand how it works under the hood if you're hoping to debug it and address it in your neural net.
1:37
So this blog post from a while ago goes into some of those examples. We've already covered some of them: for example, the flat tails of these functions and how you do not want to saturate them too much, because your gradients will die; the case of dead neurons, which I've already covered as well; and the case of exploding or vanishing gradients in recurrent neural networks, which we are about to cover.
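To make the "flat tails" point concrete, here is a minimal sketch in plain Python (not code from the lecture; the helper names are hypothetical) of backpropagating through a single tanh. Its local derivative is 1 - tanh(x)^2, so once the input sits in a flat tail, almost no gradient reaches anything upstream.

```python
import math

def tanh_local_grad(x: float) -> float:
    """Local derivative of tanh at x: d/dx tanh(x) = 1 - tanh(x)**2."""
    t = math.tanh(x)
    return 1.0 - t * t

def grad_into_x(x: float, upstream_grad: float = 1.0) -> float:
    """Chain rule: gradient arriving at x = upstream gradient * local gradient."""
    return upstream_grad * tanh_local_grad(x)

for x in [0.0, 1.0, 3.0, 10.0]:
    print(f"x={x:5.1f}  tanh(x)={math.tanh(x):.8f}  grad={grad_into_x(x):.2e}")
```

By x = 3 the gradient is already below 1% of its value at x = 0, which is why heavily saturated tanh units learn very slowly, if at all.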
2:02
And then you will also often come across some examples in the wild. This is a snippet that I found in a random codebase on the internet, where they actually have a very subtle but pretty major bug in their implementation, and the bug points to the fact that the author of this code does not actually understand backpropagation. What they're trying to do here is clip the loss at a certain maximum value, but what they are actually after is clipping the gradients at a maximum value, not the loss. Indirectly, they're causing some of the outliers to be ignored, because when you clip the loss of an outlier, you are setting its gradient to zero. So have a look through this and read through it, but there's basically a bunch of subtle issues that you're going to avoid if you actually know what you're doing. That's why I don't think that, just because PyTorch or other frameworks offer autograd, it is okay for us to ignore how it works.
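The loss-clipping bug described above can be seen on a single scalar. This is a minimal sketch, assuming a squared-error loss and hypothetical thresholds max_loss and max_grad (it is not the snippet from the blog post): clamping the loss zeroes the outlier's gradient, while clamping the gradient keeps it pointing in the right direction with a bounded size.

```python
def squared_error(pred: float, target: float) -> float:
    return (pred - target) ** 2

def grad_of_clipped_loss(pred: float, target: float, max_loss: float) -> float:
    """d/dpred of min(squared_error, max_loss).

    Once the loss exceeds max_loss, min() picks the constant branch, so the
    gradient reaching pred is exactly zero: the example is silently ignored.
    """
    if squared_error(pred, target) > max_loss:
        return 0.0
    return 2.0 * (pred - target)

def clipped_grad_of_loss(pred: float, target: float, max_grad: float) -> float:
    """Gradient of the unclipped loss, with its magnitude limited to max_grad."""
    g = 2.0 * (pred - target)
    return max(-max_grad, min(max_grad, g))

for pred in [0.2, 5.0]:  # a typical prediction and an outlier (target is 0.0)
    print(pred,
          grad_of_clipped_loss(pred, 0.0, max_loss=1.0),
          clipped_grad_of_loss(pred, 0.0, max_grad=1.0))
```

For the outlier at 5.0, gradient clipping still pushes the prediction toward its target with a gradient of 1.0, whereas loss clipping contributes nothing.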
3:00
Now, we've actually already covered autograd, and we wrote micrograd. But micrograd was an autograd engine only on the level of individual scalars, so the atoms were single individual numbers. I don't think that's enough, and I'd like us to think about backpropagation on the level of tensors as well. So, in summary, I think it's a good exercise; I think it is very, very valuable. You're going to become better at debugging neural networks and at making sure that you understand what you're doing. It is going to make everything fully explicit, so you're not going to be nervous about what is hidden away from you. And in general, we're going to emerge stronger. So let's get into it.
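micrograd applied the chain rule one scalar at a time; on the level of tensors, the same chain rule collapses into a few matrix expressions. The sketch below uses NumPy with a toy linear layer and a toy loss 0.5 * sum(Y**2), both assumptions for illustration rather than the lecture's network. It derives dW, db and dX as whole tensors and checks dW against a scalar-by-scalar loop.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(4, 3))  # batch of 4 examples, 3 input features each
W = rng.normal(size=(3, 2))  # weights mapping 3 inputs to 2 outputs
b = rng.normal(size=(2,))    # bias, broadcast across the batch

# Forward pass with a toy loss L = 0.5 * sum(Y**2), chosen so that dL/dY = Y
Y = X @ W + b
loss = 0.5 * (Y ** 2).sum()

# Tensor-level backward pass: one expression per tensor
dY = Y                 # gradient of the toy loss with respect to Y
dW = X.T @ dY          # (3, 4) @ (4, 2) -> (3, 2), same shape as W
db = dY.sum(axis=0)    # b is reused by every row, so its gradients add up over rows
dX = dY @ W.T          # (4, 2) @ (2, 3) -> (4, 3), same shape as X

# Scalar-level (micrograd-style) backward pass for dW, for comparison.
# Y[n, j] = sum_i X[n, i] * W[i, j] + b[j], so each W[i, j] accumulates
# X[n, i] * dY[n, j] from every example n that used it.
dW_loop = np.zeros_like(W)
for n in range(X.shape[0]):
    for i in range(W.shape[0]):
        for j in range(W.shape[1]):
            dW_loop[i, j] += X[n, i] * dY[n, j]

print(np.allclose(dW, dW_loop))  # the two formulations agree
```

Two handy sanity rules show up here: every gradient has the same shape as the tensor it belongs to, and a tensor that was broadcast in the forward pass (like b) gets a sum in the backward pass.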
3:37
A bit of a fun historical note here: today, writing your backward pass by hand is not recommended, and no one does it except for the purposes of exercise. But about 10 years ago in deep learning this was fairly standard, and in fact pervasive. At the time, everyone used to write their own backward pass by hand, including myself; it's just what you would do. So we used to write backward passes by hand, and now everyone just calls loss.backward(). We've lost something. I want to give you a few examples of this.
4:07
Here's a 2006 paper from Geoff Hinton and Ruslan Salakhutdinov in Science that was influential at the time. It was training some architectures called restricted Boltzmann machines, and basically it's an autoencoder trained here. And this is from roughly 2010: I had a library for training restricted Boltzmann machines, and at the time it was written in MATLAB. Python was not used for deep learning pervasively; it was all MATLAB. MATLAB was a scientific computing package that everyone would use. So we would write MATLAB, which is barely a programming language, but it had a very convenient tensor class, and it was a whole computing environment: it would all run on a CPU, of course, but you would have very nice plots to go with it and a built-in debugger, and it was pretty nice. Now, the code in this package that I wrote in 2010 for fitting restricted Boltzmann machines is to a large extent recognizable, but I wanted to show you how you would... well, here I'm creating the data in the X, Y batches, I'm initializing the neural net so it's got weights and biases just like we're used to, and then this is the training loop where we actually do the forward pass.
5:17
And then here, at this time, they didn't even necessarily use backpropagation to train neural networks. This in particular implements contrastive divergence, which estimates a gradient, and then here we take that gradient and use it for a parameter update along the lines that we're used to. But you can see that basically people were meddling with these gradients directly, inline, themselves; it wasn't that common to use an autograd engine. Here's one more example, from a paper of mine from 2014 called Deep Fragment Embeddings. What I was doing here is aligning images and text, so it's kind of like CLIP, if you're familiar with it, but instead of working on the level of entire images and entire sentences, it was working on the level of individual objects and little pieces of sentences. I was embedding them and then calculating a very CLIP-like loss. I dug up the code from 2014 of how I implemented this, and it was already in NumPy and Python.
6:14
Here I'm implementing the cost function, and it was standard to implement not just the cost but also the backward pass manually. So here I'm calculating the image embeddings, the sentence embeddings, and the loss function: I calculate the scores, this is the loss function, and then once I have the loss function I do the backward pass right here. So I backward through the loss function and through the neural net, and I append regularization. Everything was done by hand: you would just write out the backward pass, and then you would use a gradient checker to make sure that your numerical estimate of the gradient agrees with the one you calculated during backpropagation. This was very standard for a long time, but today, of course, it is standard to use an autograd engine.
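The gradient-checking step mentioned above can be sketched in a few lines. This is an illustration under stated assumptions, not the 2014 code: a tiny logistic regression with a hand-derived gradient (hypothetical function names), compared against a central-difference estimate using a relative error.

```python
import numpy as np

def loss_and_grad(w, x, y):
    """Hand-derived loss and gradient for a tiny logistic regression.

    loss    = mean(-y*log(p) - (1-y)*log(1-p)),  p = sigmoid(x @ w)
    dloss/dw = x.T @ (p - y) / N
    """
    p = 1.0 / (1.0 + np.exp(-(x @ w)))
    loss = np.mean(-y * np.log(p) - (1 - y) * np.log(1 - p))
    grad = x.T @ (p - y) / x.shape[0]
    return loss, grad

def numerical_grad(f, w, eps=1e-5):
    """Central-difference estimate of df/dw, one coordinate at a time."""
    num = np.zeros_like(w)
    for i in range(w.size):
        w_plus, w_minus = w.copy(), w.copy()
        w_plus[i] += eps
        w_minus[i] -= eps
        num[i] = (f(w_plus) - f(w_minus)) / (2 * eps)
    return num

rng = np.random.default_rng(0)
x = rng.normal(size=(8, 3))
y = (rng.random(8) > 0.5).astype(float)
w = rng.normal(size=3) * 0.1

_, analytic = loss_and_grad(w, x, y)
numeric = numerical_grad(lambda v: loss_and_grad(v, x, y)[0], w)
denom = max(np.abs(analytic).max() + np.abs(numeric).max(), 1e-12)
rel_err = np.abs(analytic - numeric).max() / denom
print(rel_err)
```

A relative error rather than an absolute one is used so the check behaves the same whether the gradients are large or tiny.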
6:56
But writing the backward pass by hand was definitely useful, and I think people sort of understood how these neural networks work on a very intuitive level. So I think it's a good exercise again, and this is where we want to be. Okay, so just as a reminder, from our previous lecture, this is the Jupyter notebook that we implemented at the time. We're going to keep everything the same: we're still going to have a two-layer multilayer perceptron with a batch normalization layer, so the forward pass will be basically identical to that lecture, but here we're going to get rid of loss.backward() and instead we're going to write the backward pass manually.
7:26
Now here's the starter code for this lecture: we are becoming a backprop ninja in this notebook. The first few cells here are identical to what we are used to: we are doing some imports, loading the dataset, and processing the dataset. None of this changed.
7:41
Now here I'm introducing a utility function that we're going to use later to compare the gradients. In particular, we are going to have the gradients that we estimate manually ourselves and the gradients that PyTorch calculates, and we're going to check ours for correctness, assuming of course that PyTorch is correct.
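A comparison utility along these lines might look like the following sketch. The name compare_grads and its report format are hypothetical, and NumPy arrays stand in for the PyTorch tensors (with PyTorch, the reference would be the parameter's .grad and torch.allclose plays the role of np.allclose). It reports an exact match, an approximate match within tolerance, and the largest absolute difference.

```python
import numpy as np

def compare_grads(name, manual, reference, rtol=1e-5, atol=1e-8):
    """Report how closely a hand-derived gradient matches a reference gradient.

    exact:   every element is bit-for-bit identical
    approx:  every element agrees within rtol/atol
    maxdiff: the largest absolute elementwise difference
    """
    exact = bool(np.array_equal(manual, reference))
    approx = bool(np.allclose(manual, reference, rtol=rtol, atol=atol))
    maxdiff = float(np.abs(manual - reference).max())
    print(f"{name:15s} | exact: {str(exact):5s} | approximate: {str(approx):5s} | maxdiff: {maxdiff}")
    return exact, approx, maxdiff

ref = np.array([0.1, -0.2, 0.3])  # stands in for the autograd result
compare_grads("same", ref.copy(), ref)
compare_grads("tiny noise", ref + 1e-12, ref)
```

Reporting "approximate" separately from "exact" matters because a mathematically correct manual gradient can still differ from autograd in the last few bits, since the floating-point operations happen in a different order.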
7:58
Then here we have the initialization that we are quite used to: our embedding table for the characters, the first layer, the second layer, and the batch normalization in between, and here's where we create all the parameters. Now, you will note that I changed the initialization a little bit to use small numbers. Normally you would set the biases to be all zero; here I am setting them to small random numbers. I'm doing this because if your variables are initialized to exactly zero, that can sometimes mask an incorrect implementation of a gradient: when everything is zero, the expression simplifies and gives you a much simpler gradient than you would otherwise get. So by making them small numbers, I'm trying to unmask potential errors in these calculations.
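Here is a toy illustration of how a zero bias can hide a gradient bug. The model, its loss, and the deliberate bug (forgetting the bias inside the chain rule) are assumptions made up for this sketch, not code from the notebook. With b = 0 the buggy gradient matches the numerical reference exactly; with a small nonzero b the mismatch shows up.

```python
def loss(w: float, b: float, x: float) -> float:
    # Toy model: one linear unit followed by a squared output
    return (w * x + b) ** 2

def dloss_dw_correct(w: float, b: float, x: float) -> float:
    return 2.0 * (w * x + b) * x

def dloss_dw_buggy(w: float, b: float, x: float) -> float:
    # Deliberate bug: the bias was dropped inside the chain rule
    return 2.0 * (w * x) * x

def numerical_dw(w: float, b: float, x: float, eps: float = 1e-6) -> float:
    # Central-difference estimate used as the reference
    return (loss(w + eps, b, x) - loss(w - eps, b, x)) / (2 * eps)

w, x = 0.5, 1.5
for b in [0.0, 0.05]:  # zero-initialized bias vs. a small nonzero bias
    ref = numerical_dw(w, b, x)
    print(f"b={b}: correct err={abs(dloss_dw_correct(w, b, x) - ref):.1e}, "
          f"buggy err={abs(dloss_dw_buggy(w, b, x) - ref):.1e}")
```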
8:43
You'll also notice that I'm using b1 in the first layer: I'm using a bias despite the batch normalization right afterwards. This would typically not be what you do, because we talked about the fact that you don't need the bias, but I'm doing it here just for fun, because we're going to have a gradient with respect to it, and we can check that we are still calculating it correctly even though this bias is spurious.
9:05
asparious
9:05
asparious so here I'm calculating a single batch
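Here is a quick way to see why the bias is spurious. Assume batch normalization uses the batch mean, as it does during training. The mean subtraction then cancels any per-feature constant, so the gradient reaching b1 should be zero up to round-off. The shapes, scales, and the tanh-squared loss below are arbitrary choices for illustration.

```python
import torch

torch.manual_seed(0)
n, d_in, d_hidden = 32, 10, 64
x = torch.randn(n, d_in)
W1 = torch.randn(d_in, d_hidden) * 0.1
b1 = (torch.randn(d_hidden) * 0.1).requires_grad_()  # the bias in question
bngain = torch.ones(1, d_hidden)
bnbias = torch.zeros(1, d_hidden)

hprebn = x @ W1 + b1
# Batch normalization with batch statistics: subtracting the batch mean
# removes any constant added to a feature, b1 included
bnmean = hprebn.mean(0, keepdim=True)
bnvar = ((hprebn - bnmean) ** 2).mean(0, keepdim=True)
hpreact = bngain * (hprebn - bnmean) / torch.sqrt(bnvar + 1e-5) + bnbias

loss = torch.tanh(hpreact).square().mean()  # arbitrary downstream loss
loss.backward()

print(b1.grad.abs().max().item())  # zero up to float32 round-off
```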
9:07
So here I'm calculating a single batch, and then here I'm doing a forward pass. You'll notice that the forward pass is significantly expanded from what we are used to; previously, the forward pass was just this part here. The forward pass is longer for two reasons. Number one, before we just had an F.cross_entropy call, but here I am bringing back an explicit implementation of the loss function. Number two, I've broken up the implementation into manageable chunks, so we have a lot more intermediate tensors along the way in the forward pass. That's because we are about to go backwards and calculate the gradients in this backpropagation from the bottom to the top. Just like we have, for example, the logprobs tensor in the forward pass, in the backward pass we're going to have a dlogprobs, which will store the derivative of the loss with respect to the logprobs tensor. So we're going to be prepending d to every one of these tensors and calculating it along the way of this backpropagation.
10:07
As an example, we have a bnraw here, so we're going to be calculating a dbnraw. Here I'm telling PyTorch that we want to retain the grad of all these intermediate values, because in exercise one we're going to calculate the backward pass ourselves. We're going to calculate all these d variables and use the cmp function I've introduced above to check our correctness with respect to what PyTorch is telling us. That is exercise one, where we backpropagate through this entire graph.
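Here is a minimal sketch of the idea, using a random logits tensor in place of the real network. The cross-entropy is written out in small named steps. The names echo the lecture's tensors, but the notebook's exact breakdown may differ. retain_grad() keeps .grad on the non-leaf intermediates, and the hand-written loss is checked against F.cross_entropy.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n, vocab_size = 32, 27
logits = torch.randn(n, vocab_size, requires_grad=True)  # stand-in for the network output
Yb = torch.randint(0, vocab_size, (n,))

# Cross-entropy written out as small named steps
logit_maxes = logits.max(1, keepdim=True).values
norm_logits = logits - logit_maxes  # subtract row max for numerical stability
counts = norm_logits.exp()
counts_sum = counts.sum(1, keepdim=True)
probs = counts / counts_sum
logprobs = probs.log()
loss = -logprobs[range(n), Yb].mean()

# Ask autograd to keep .grad on these non-leaf tensors (must happen before backward)
for t in [logit_maxes, norm_logits, counts, counts_sum, probs, logprobs]:
    t.retain_grad()
loss.backward()

print(torch.allclose(loss, F.cross_entropy(logits, Yb)))  # True
print(logprobs.grad.shape)  # torch.Size([32, 27]), same shape as logprobs
```

Each d-variable computed by hand in exercise one would then be compared against the matching .grad, e.g. dlogprobs against logprobs.grad.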
10:36
Now, just to give you a very quick preview of what's going to happen in exercise two and below. In exercise one we have fully broken up the loss and backpropagated through it manually, in all the little atomic pieces that make it up. In exercise two, we're going to collapse the loss into a single cross-entropy call. Then, using math and paper and pencil, we're going to analytically derive the gradient of the loss with respect to the logits. Instead of backpropagating through all of its little chunks one at a time, we're just going to derive what that gradient is and implement that, which is much more efficient, as we'll see in a bit.
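For reference, the closed form one expects to arrive at is the standard result for a mean-reduced cross-entropy: softmax(logits) minus the one-hot labels, divided by n. The sketch below uses random float64 logits rather than the lecture's network and checks that expression against autograd.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n, vocab_size = 32, 27
logits = torch.randn(n, vocab_size, dtype=torch.float64, requires_grad=True)
Yb = torch.randint(0, vocab_size, (n,))

loss = F.cross_entropy(logits, Yb)  # mean over the batch
loss.backward()

# Closed form: dloss/dlogits = (softmax(logits) - one_hot(Yb)) / n
dlogits = F.softmax(logits.detach(), dim=1)
dlogits[range(n), Yb] -= 1
dlogits /= n

print(torch.allclose(dlogits, logits.grad))  # True
```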
11:12
Then we're going to do the exact same thing for batch normalization. Instead of breaking up batchnorm into all of its tiny components, we're going to use pen and paper, mathematics, and calculus to derive the gradient through the batchnorm layer. So we're going to calculate the backward pass through the batchnorm layer as a much more efficient expression, instead of backpropagating through all of its little pieces independently. That's going to be exercise three.
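Here is a sketch of what a single-expression batchnorm backward can look like. It assumes the biased (1/n) batch variance and a random upstream gradient dhpreact. With an unbiased (1/(n-1)) variance, the last term picks up an extra factor of n/(n-1). The variable names, such as bnvar_inv and dhprebn, are illustrative.

```python
import torch

torch.manual_seed(0)
n, d, eps = 32, 64, 1e-5
hprebn = torch.randn(n, d, dtype=torch.float64, requires_grad=True)
bngain = torch.randn(1, d, dtype=torch.float64)
bnbias = torch.randn(1, d, dtype=torch.float64)

# Forward pass with batch statistics and the biased (1/n) variance
bnmean = hprebn.mean(0, keepdim=True)
bnvar = ((hprebn - bnmean) ** 2).mean(0, keepdim=True)
bnvar_inv = (bnvar + eps) ** -0.5
bnraw = (hprebn - bnmean) * bnvar_inv
hpreact = bngain * bnraw + bnbias

dhpreact = torch.randn(n, d, dtype=torch.float64)  # stand-in upstream gradient
hpreact.backward(dhpreact)  # autograd reference, stored in hprebn.grad

with torch.no_grad():
    # The whole batchnorm backward as one expression
    dbnraw = bngain * dhpreact
    dhprebn = bnvar_inv / n * (
        n * dbnraw
        - dbnraw.sum(0, keepdim=True)
        - bnraw * (dbnraw * bnraw).sum(0, keepdim=True)
    )

print(torch.allclose(dhprebn, hprebn.grad))  # True
```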
11:38
Then in exercise four, we're going to put it all together. This is the full code for training this two-layer MLP. We're going to insert our manual backprop and take out loss.backward(), and you will see that you can get all the same results using fully your own code. The only thing we're using from PyTorch is torch.tensor, to make the calculations efficient. Otherwise, you will understand fully what it means to forward and backward a neural net and train it, and I think that'll be awesome. So let's get to it.
12:10
Okay, so I ran all the cells of this notebook all the way up to here. I'm going to erase this and start implementing the backward pass, starting with dlogprobs. We want to understand what should go here to calculate the gradient of the loss with respect to all the elements of the logprobs tensor.
12:26
Now, I'm going to give away the answer here, but I wanted to put a quick note first. I think the most pedagogically useful thing for you is to go into the description of this video and find the link to this Jupyter notebook. You can find it on GitHub, and you can also find a Google Colab version, so you don't have to install anything; you just go to a website on Google Colab. Then you can try to implement these derivatives or gradients yourself, and if you are not able to, come to my video and see me do it. So work in tandem: try it first yourself, and then see me give away the answer. I think that'll be most valuable to you, and that's how I recommend you go through this lecture.
13:03
So we are starting here with dlogprobs. dlogprobs will hold the derivative of the loss with respect to all the elements of logprobs. What is inside logprobs? Its shape is 32 by 27, so it's not going to surprise you that dlogprobs should also be an array of size 32 by 27, because we want the derivative of the loss with respect to all of its elements. The sizes of those are always going to be equal.
13:33
Now, how does logprobs influence the loss? The loss is negative logprobs indexed with range(n) and Yb, and then the mean of that. Just as a reminder, Yb is basically an array of all the correct indices. So what we're doing here is taking the logprobs array of size 32 by 27 and going into every single row. In each row we are plucking out the index 8, then 14, then 15, and so on. We're going down the rows (that's the iterator range(n)), and we are always plucking out the column index specified by the tensor Yb. In the zeroth row we are taking the eighth column, in the first row the 14th column, etc. So logprobs indexed this way plucks out all the log probabilities of the correct next character in a sequence. That's what it does, and the size of the result is of course 32, because our batch size is 32. These elements get plucked out, and then the negative of their mean becomes the loss.
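The plucking described above can be checked directly. This sketch uses random log-probabilities and labels in place of the real batch. It confirms that logprobs[range(n), Yb] selects one element per row, matching both an explicit loop and torch.gather.

```python
import torch

torch.manual_seed(0)
n, vocab_size = 32, 27
logprobs = torch.randn(n, vocab_size).log_softmax(1)  # stand-in log-probabilities
Yb = torch.randint(0, vocab_size, (n,))  # one correct next-character index per row

# Row i contributes exactly one element: logprobs[i, Yb[i]]
picked = logprobs[range(n), Yb]
print(picked.shape)  # torch.Size([32])

# The same selection as an explicit loop and with torch.gather
looped = torch.stack([logprobs[i, Yb[i]] for i in range(n)])
gathered = logprobs.gather(1, Yb.unsqueeze(1)).squeeze(1)
assert torch.equal(picked, looped) and torch.equal(picked, gathered)

loss = -picked.mean()  # negative mean of the plucked log-probabilities
```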
14:49
I always like to work with simpler examples to understand the numerical form of the derivative. What's going on here is that once we've plucked out these examples, we take the mean and then the negative. So I can write the loss as the negative of, say, a plus b plus c, divided by three. That's how we would take the mean of three numbers a, b, c, although we actually have 32 numbers here. So what is dloss by, say, da? If we simplify this expression mathematically, the loss is -1/3 of a, plus -1/3 of b, plus -1/3 of c, so dloss by da is just -1/3. You can see that if we don't just have a, b, and c but 32 numbers, then dloss by each one of those numbers is going to be -1/n more generally, where n is the size of the batch, 32 in this case. So dloss by dlogprobs is -1/n in all these places.
16:04
Now, what about the other elements inside logprobs? logprobs is a large array: logprobs.shape is 32 by 27, but only 32 of its elements participate in the loss calculation. So what's the derivative for all the other elements, the majority, that do not get plucked out here? Intuitively, their gradient is zero, because they did not participate in the loss. Most of the numbers inside this tensor do not feed into the loss, so if we were to change them, the loss wouldn't change. That is the equivalent way of saying that the derivative of the loss with respect to them is zero: they don't impact it.
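Here is the derivation above in equation form. a, b, c stand for three plucked log-probabilities, and the result is then generalized to a batch of n examples with labels y_i = Yb[i], covering both the plucked and the unplucked entries:

```latex
\mathcal{L} = -\frac{a + b + c}{3}
  \quad\Longrightarrow\quad
  \frac{\partial \mathcal{L}}{\partial a} = -\frac{1}{3}

\mathcal{L} = -\frac{1}{n}\sum_{i=1}^{n} \mathrm{logprobs}_{i,\,y_i}
  \quad\Longrightarrow\quad
  \frac{\partial \mathcal{L}}{\partial\, \mathrm{logprobs}_{i,j}} =
  \begin{cases}
    -\dfrac{1}{n} & \text{if } j = y_i \\[4pt]
    0 & \text{otherwise}
  \end{cases}
```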
16:45
So here's a way to implement this derivative. We start out with torch.zeros of shape 32 by 27. Because we don't want to hard-code numbers, let's use torch.zeros_like(logprobs) instead, which creates an array of zeros in exactly the shape of logprobs. Then we need to set the derivative of -1/n at exactly these locations. dlogprobs indexed in the identical way is simply set to -1.0/n, just like we derived. Now let me erase all this reasoning; this is the candidate derivative for dlogprobs. Let's uncomment the first cmp line and check that it is correct.
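The candidate dlogprobs can be reproduced in isolation. In this sketch, logprobs comes from log_softmax of random logits rather than from the lecture's network. Exact equality is expected here because 1/32 is exactly representable in floating point.

```python
import torch

torch.manual_seed(0)
n, vocab_size = 32, 27
logits = torch.randn(n, vocab_size, requires_grad=True)
Yb = torch.randint(0, vocab_size, (n,))

logprobs = logits.log_softmax(1)
logprobs.retain_grad()  # keep .grad on this non-leaf tensor
loss = -logprobs[range(n), Yb].mean()
loss.backward()

# Manual gradient: zero everywhere except -1/n at each row's correct-label column
dlogprobs = torch.zeros_like(logprobs)
dlogprobs[range(n), Yb] = -1.0 / n

print(torch.all(dlogprobs == logprobs.grad).item())  # True
```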
17:39
Okay, so cmp ran. Let's go back to cmp and see what it's doing. It checks whether the value we calculated, dt, is exactly equal to t.grad as calculated by PyTorch. It makes sure that all the elements are exactly equal and then converts this to a single Boolean value, because we don't want a Boolean tensor, just a Boolean value. Then, in case they're not exactly equal, it checks whether they are approximately equal: floating-point issues can make them differ slightly while still being very, very close. For this we use torch.allclose, which allows a little bit of wiggle room. If you use a slightly different calculation, floating-point arithmetic can give you a slightly different result, so this checks whether you get an approximately close result. Finally, it finds the element with the highest difference and reports the absolute value of that difference. So we are printing whether we have exact equality, approximate equality, and what the largest difference is.
18:46
and so here
18:46
and so here we see that we actually have exact
18:48
we see that we actually have exact
18:48
we see that we actually have exact equality and so therefore of course we
18:50
equality and so therefore of course we
18:50
equality and so therefore of course we also have an approximate equality and
18:52
also have an approximate equality and
18:52
also have an approximate equality and the maximum difference is exactly zero
18:54
the maximum difference is exactly zero
18:54
the maximum difference is exactly zero so basically our d-log props is exactly
18:57
so basically our d-log props is exactly
18:57
so basically our d-log props is exactly equal to what pytors calculated to be
19:00
equal to what pytors calculated to be
19:00
equal to what pytors calculated to be lockprops.grad in its back propagation
19:03
lockprops.grad in its back propagation
19:03
lockprops.grad in its back propagation so so far we're working pretty well okay
19:06
so so far we're working pretty well okay
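A plain-Python sketch of the comparison being described (the notebook itself works on PyTorch tensors). It assumes the two gradients are flat lists of floats. compare_grads is a hypothetical name. The tolerance test copies the per-element rule documented for torch.allclose, |a - b| <= atol + rtol * |b|, with its default rtol=1e-05 and atol=1e-08.

```python
def compare_grads(name, mine, reference, rtol=1e-05, atol=1e-08):
    """Report exact equality, approximate equality, and the largest absolute difference.

    `mine` and `reference` are flat lists of floats standing in for tensors.
    """
    if len(mine) != len(reference):
        raise ValueError("gradients must have the same number of elements")
    exact = all(m == r for m, r in zip(mine, reference))
    # Same per-element rule as torch.allclose: |a - b| <= atol + rtol * |b|
    approx = all(abs(m - r) <= atol + rtol * abs(r) for m, r in zip(mine, reference))
    maxdiff = max(abs(m - r) for m, r in zip(mine, reference))
    print(f"{name:15s} | exact: {str(exact):5s} | approximate: {str(approx):5s} | maxdiff: {maxdiff}")
    return exact, approx, maxdiff


# Two ways of computing the "same" number can disagree in the last bits,
# so exact is False here while approximate is True.
compare_grads("demo", [0.1 + 0.2, 1.0], [0.3, 1.0])
```

Checking exact equality first and falling back to a tolerance check is what lets you tell "bit-for-bit identical" apart from "identical up to floating-point noise".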
19:07
Okay, so let's now continue our backpropagation. We have that logprobs depends on probs through a log: all the elements of probs have log applied to them element-wise. Now, if we want dprobs, then remember your micrograd training. We have a log node that takes in probs and creates logprobs, and dprobs will be the local derivative of that individual operation, log, times the derivative of the loss with respect to its output, which in this case is dlogprobs.
19:39
So what is the local derivative of this operation? Well, we are taking log element-wise, and we can come here and see (WolframAlpha is your friend) that d/dx of log(x) is simply 1/x. In this case x is probs, so d/dx is 1/x, which is 1/probs. That is the local derivative, and then we want to chain it, so by the chain rule we multiply by dlogprobs.
20:06
Let me uncomment this and run the cell in place, and we see that the derivative of probs as we calculated it here is exactly correct.
20:15
And so notice how this works: probs is going to be inverted and then element-wise multiplied here. So if your probs is very, very close to one, meaning your network is currently predicting the character correctly, then this becomes 1/1 and dlogprobs just gets passed through. But if your probabilities are incorrectly assigned, so if the correct character is getting a very low probability, then dividing 1.0 by it will boost this, and then it is multiplied by dlogprobs. So basically, what this line is doing intuitively is taking the examples that currently have a very low probability assigned and boosting their gradient. You can look at it that way.
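A minimal plain-Python sketch of this step. It assumes probs and dlogprobs are flat lists, and the function name log_backward is hypothetical. The sketch applies the local derivative 1/p element-wise and multiplies by the upstream gradient. A central finite difference confirms the d/dx log(x) = 1/x fact used above.

```python
import math


def log_backward(probs, dlogprobs):
    """Chain rule through an element-wise log: dprobs = (1 / probs) * dlogprobs."""
    return [(1.0 / p) * g for p, g in zip(probs, dlogprobs)]


def numerical_derivative(f, x, h=1e-6):
    """Central finite-difference estimate of f'(x)."""
    return (f(x + h) - f(x - h)) / (2 * h)


# Check the local derivative of log at a few points.
for x in (0.05, 0.5, 0.99):
    assert abs(numerical_derivative(math.log, x) - 1.0 / x) < 1e-4

# A confident prediction (p close to 1) passes the gradient through almost unchanged,
# while a low probability on the correct character scales it up by 1/p.
upstream = [-1.0 / 32, -1.0 / 32]
print(log_backward([0.99, 0.01], upstream))
```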
20:59
Next up is counts_sum_inv, so we want the derivative of this. Now let me just pause here and introduce what's happening in general, because I know it's a little bit confusing. We have the logits that come out of the neural net. What I'm doing here is finding the maximum in each row and subtracting it, for the purposes of numerical stability. We talked about how, if you do not do this, you run into numerical issues when some of the logits take on values that are too large, because we end up exponentiating them. So this is done just for numerical safety. Then here's the exponentiation of all the logits to create our counts, and then we want to take the sum of these counts and normalize, so that all of the probs sum to one.
21:43
Now here, instead of using 1/counts_sum, I raise it to the power of negative one. Mathematically they are identical. I just found that there's something wrong with the PyTorch implementation of the backward pass of division, and it gives a weird result, but that doesn't happen for **-1, so that's why I'm using this formula instead. Basically, all that's happening here is that we take the logits, exponentiate all of them, and normalize the counts to create our probabilities. It's just that it's happening across multiple lines.
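To make those multiple lines concrete, here is a plain-Python sketch of the same forward steps. It assumes logits is a list of lists with one row per example. counts, counts_sum, counts_sum_inv and probs mirror the names discussed here. logit_max and norm_logits are my own labels for the row maximum and the shifted logits.

```python
import math


def softmax_rows(logits):
    """Row-wise softmax split into the same small steps as the forward pass described above."""
    probs = []
    for row in logits:
        logit_max = max(row)                          # subtract the row max for numerical stability
        norm_logits = [x - logit_max for x in row]
        counts = [math.exp(x) for x in norm_logits]   # largest entry becomes exp(0) = 1
        counts_sum = sum(counts)
        counts_sum_inv = counts_sum ** -1             # mathematically identical to 1 / counts_sum
        probs.append([c * counts_sum_inv for c in counts])
    return probs


# Without the max subtraction, math.exp(1000.0) would raise OverflowError.
p = softmax_rows([[1000.0, 1000.0, 999.0], [0.0, 1.0, 2.0]])
print([round(sum(row), 6) for row in p])  # each row sums to 1
```

Subtracting the maximum does not change the result, because the same factor exp(-max) appears in every count and in their sum, so it cancels in the normalization.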
22:14
So now we want to first take the derivative: we want to backpropagate into counts_sum_inv, and then into counts as well. So what should dcounts_sum_inv be? We actually have to be careful here, because we have to scrutinize the shapes. counts.shape and counts_sum_inv.shape are different: in particular, counts is 32 by 27, but counts_sum_inv is 32 by 1. So in this multiplication we also have an implicit broadcast that PyTorch will do, because it needs to take this column tensor of 32 numbers and replicate it horizontally 27 times to align the two tensors so that it can do an element-wise multiply.
23:03
So really, using a toy example again, this looks like the following. What we really have here is just probs = counts * counts_sum_inv, so it's c = a * b, but a is 3 by 3 and b is just 3 by 1, a column tensor. PyTorch internally replicated the elements of b across all the columns. For example, b1, the first element of b, would be replicated across all the columns in this multiplication.
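A small plain-Python sketch of what the broadcast is doing. It assumes a is a 3 by 3 list of lists and b is a 3 by 1 column, with values made up for illustration. Replicating b across the columns first and then multiplying element-wise gives the same c as the implicit broadcast.

```python
a = [[1.0, 2.0, 3.0],
     [4.0, 5.0, 6.0],
     [7.0, 8.0, 9.0]]
b = [[10.0], [20.0], [30.0]]  # column tensor, shape 3 x 1

n_cols = len(a[0])

# Step 1: replication. Copy each b[i][0] across all columns -> shape 3 x 3.
b_replicated = [[row[0]] * n_cols for row in b]

# Step 2: element-wise multiplication of two same-shaped arrays.
c = [[a[i][j] * b_replicated[i][j] for j in range(n_cols)] for i in range(len(a))]

# The broadcast version reads b[i][0] directly for every column j; the result is identical.
c_broadcast = [[a[i][j] * b[i][0] for j in range(n_cols)] for i in range(len(a))]
assert c == c_broadcast
print(c)  # [[10.0, 20.0, 30.0], [80.0, 100.0, 120.0], [210.0, 240.0, 270.0]]
```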
23:31
And now we're trying to backpropagate through this operation into counts_sum_inv. When we're calculating this derivative, it's important to realize that what looks like a single operation is actually two operations applied sequentially. The first operation PyTorch did is take this column tensor and replicate it across all the columns, basically 27 times. So the first operation is a replication, and the second operation is the multiplication. Let's first backpropagate through the multiplication.
24:05
If these two arrays were the same size, with a and b both 3 by 3, how would we backpropagate through a multiplication? If we just have scalars and not tensors, and c = a * b, what is the derivative of c with respect to b? Well, it's just a, and that's the local derivative.
24:27
So in our case, backpropagating through just the multiplication itself, which is element-wise, gives the local derivative, which is simply counts, because counts is the a. So this is the local derivative, and then, because of the chain rule, we multiply by dprobs. But this is the derivative, or gradient, with respect to the replicated b.
24:54
But we don't have a replicated b; we just have a single b column. So how do we now backpropagate through the replication? Intuitively, b1 is the same variable, just reused multiple times, so you can look at this as equivalent to a case we've encountered in micrograd.
25:12
Here I'm just pulling out a random graph we used in micrograd. We had an example where a single node's output feeds into two branches of the graph until the final function, and we talked about how the correct thing to do in the backward pass is to sum all the gradients that arrive at any one node. Across these different branches the gradients sum: if a node is used multiple times, the gradients from all of its uses are summed during backpropagation.
25:44
So here b1 is used multiple times, in all these columns, and therefore the right thing to do is to sum horizontally across each row. So I'm going to sum in dimension 1. But we want to retain this dimension so that counts_sum_inv and its gradient have exactly the same shape, so we pass keepdim=True so that we don't lose this dimension. This makes dcounts_sum_inv exactly of shape 32 by 1.
26:14
Uncommenting this comparison as well and running it, we see that we get an exact match, so this derivative is exactly correct.
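Both backward steps can be checked numerically: first the multiplication, then undoing the replication by summing. This plain-Python sketch assumes a toy scalar loss L = sum over i, j of dc[i][j] * c[i][j]. That loss is chosen only so that dL/dc is exactly the made-up upstream gradient dc. The summed gradient db keeps b's 3 by 1 shape, the analogue of .sum(1, keepdim=True).

```python
def forward(a, b):
    """c = a * b, with the column b broadcast across the columns of a."""
    return [[a[i][j] * b[i][0] for j in range(len(a[0]))] for i in range(len(a))]


def grad_b(a, dc):
    """Backward into b: local derivative a times upstream dc, summed across each row.

    Returns a list of one-element lists so db has the same 3 x 1 shape as b
    (the analogue of .sum(1, keepdim=True)).
    """
    return [[sum(a[i][j] * dc[i][j] for j in range(len(a[0])))] for i in range(len(a))]


def toy_loss(a, b, dc):
    """Scalar loss whose gradient with respect to c is exactly dc."""
    c = forward(a, b)
    return sum(dc[i][j] * c[i][j] for i in range(len(c)) for j in range(len(c[0])))


a = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
b = [[0.5], [-1.0], [2.0]]
dc = [[0.1, -0.2, 0.3], [0.4, 0.5, -0.6], [-0.7, 0.8, 0.9]]  # made-up upstream gradient

db = grad_b(a, dc)

# Finite-difference check: nudge each b[i][0] and watch the loss.
h = 1e-6
for i in range(len(b)):
    b_plus = [row[:] for row in b]
    b_minus = [row[:] for row in b]
    b_plus[i][0] += h
    b_minus[i][0] -= h
    numeric = (toy_loss(a, b_plus, dc) - toy_loss(a, b_minus, dc)) / (2 * h)
    assert abs(numeric - db[i][0]) < 1e-6
print(db)
```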
26:24
Let me erase this. Now let's also backpropagate into counts, the other variable used here to create probs. We just went from probs to counts_sum_inv; let's go into counts as well. So what will dcounts be? In c = a * b, counts is the a, so dc/da is just b. Therefore it's counts_sum_inv, times dprobs by the chain rule.
26:54
Now, counts_sum_inv is 32 by 1 and dprobs is 32 by 27, so those will broadcast fine and give us dcounts. There's no additional summation required here. A broadcast does happen in this multiply, because counts_sum_inv needs to be replicated again to multiply dprobs correctly, but that gives the correct result as far as this single operation is concerned.
27:23
So we've backpropagated from probs to counts, but we can't actually check the derivative of counts yet; I have that check much later on. The reason is that counts_sum_inv depends on counts, so there's a second branch that we have to finish: counts_sum_inv backpropagates into counts_sum, and counts_sum will backpropagate into counts. So counts is a node that is used twice. It's used right here to create probs, and it also goes through this other branch through counts_sum_inv. So even though we've calculated its first contribution, we still have to calculate its second contribution later.
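The accumulation rule can be seen on a tiny scalar graph, in the spirit of the micrograd example. This plain-Python sketch uses a made-up function, not the notebook's code. x feeds two branches, y = 3 * x and z = x ** 2, and the loss is L = y * z (so L = 3x^3 and dL/dx = 9x^2). The first branch's contribution alone disagrees with a finite-difference estimate, and the sum of both contributions matches it.

```python
def two_branch_graph(x):
    """x is used twice: once in y, once in z. Returns (L, y, z)."""
    y = 3.0 * x      # branch 1 uses x
    z = x ** 2       # branch 2 also uses x
    return y * z, y, z


x = 1.5
L, y, z = two_branch_graph(x)

# Backward through L = y * z: dL/dy = z and dL/dz = y.
dy = z
dz = y

# Contribution of x through each branch (local derivative times upstream gradient).
dx_branch1 = 3.0 * dy          # dy/dx = 3
dx_branch2 = 2.0 * x * dz      # dz/dx = 2x
dx = dx_branch1 + dx_branch2   # gradients from every use of x are summed

# Central finite-difference estimate of dL/dx.
h = 1e-6
numeric = (two_branch_graph(x + h)[0] - two_branch_graph(x - h)[0]) / (2 * h)

print(dx_branch1, dx, numeric)
assert abs(dx - numeric) < 1e-5          # the summed gradient matches
assert abs(dx_branch1 - numeric) > 1.0   # one branch on its own does not
```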
27:58
Okay, so we're continuing with this branch. We have the derivative for counts_sum_inv, and now we want the derivative of counts_sum. So what does dcounts_sum equal; what is the local derivative of this operation? This is basically an element-wise 1/counts_sum, since counts_sum raised to the power of negative one is the same as 1/counts_sum. If we go to WolframAlpha, we see that d/dx of x^-1 is -x^-2: negative one over x squared is the same as -x^-2.
28:32
So for dcounts_sum, the local derivative is -counts_sum**-2, and by the chain rule we multiply it by dcounts_sum_inv. So that's dcounts_sum.
28:49
Let's uncomment this and check that I am correct. Okay, so we have perfect equality, and there's no sketchiness going on here with any shapes, because these are of the same shape.
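A plain-Python check of this local derivative for one row. It assumes a made-up row of counts and a made-up upstream gradient dprobs. Here counts_sum is treated as an independent input, as if its link back to counts were cut. A finite difference with respect to it then isolates exactly the step dcounts_sum = -counts_sum**-2 * dcounts_sum_inv.

```python
counts = [1.0, 2.0, 5.0]          # one made-up row of counts
dprobs = [0.3, -0.1, 0.2]         # made-up upstream gradient dL/dprobs


def loss_given_sum(counts_sum):
    """Toy loss L = sum_j dprobs[j] * counts[j] * counts_sum**-1, with counts held fixed."""
    counts_sum_inv = counts_sum ** -1
    return sum(g * c * counts_sum_inv for g, c in zip(dprobs, counts))


counts_sum = sum(counts)
# Backward through probs = counts * counts_sum_inv: summed over the row, as derived earlier.
dcounts_sum_inv = sum(c * g for c, g in zip(counts, dprobs))
# Backward through counts_sum_inv = counts_sum**-1: local derivative -x**-2, times chain rule.
dcounts_sum = -counts_sum ** -2 * dcounts_sum_inv   # ** binds tighter than unary minus

h = 1e-6
numeric = (loss_given_sum(counts_sum + h) - loss_given_sum(counts_sum - h)) / (2 * h)
print(dcounts_sum, numeric)
assert abs(dcounts_sum - numeric) < 1e-8
```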
29:02
same shape okay next up we want to back propagate through this line we have that
29:04
propagate through this line we have that
29:04
propagate through this line we have that count sum it's count.sum along the rows
29:07
count sum it's count.sum along the rows
29:07
count sum it's count.sum along the rows so I wrote out
29:09
so I wrote out
29:09
so I wrote out um some help here we have to keep in
29:11
um some help here we have to keep in
29:11
um some help here we have to keep in mind that counts of course is 32 by 27
29:13
mind that counts of course is 32 by 27
29:13
mind that counts of course is 32 by 27 and count sum is 32 by 1. so in this
29:17
and count sum is 32 by 1. so in this
29:17
and count sum is 32 by 1. so in this back propagation we need to take this
29:19
back propagation we need to take this
29:19
back propagation we need to take this column of derivatives and transform it
29:22
column of derivatives and transform it
29:22
column of derivatives and transform it into a array of derivatives
29:24
into a array of derivatives
29:24
into a array of derivatives two-dimensional array
29:28
So what is this operation doing? We're taking in some kind of input, say a 3 by 3 matrix A, and we are summing up the rows into a column tensor B: b1, b2, b3. That is basically this. Now we have the derivatives of the loss with respect to B, all the elements of B, and now we want the derivative of the loss with respect to all these little a's. So how the b's depend on the a's is basically what we're after: what is the local derivative of this operation? Well, we can see here that b1 only depends on these elements here, the first row. The derivative of b1 with respect to all of these elements down here is zero, but for these elements here, like a11, a12, etc., the local derivative is one. So db1 by da11, for example, is one, and likewise for a12 and a13. So when we have the derivative of the loss with respect to b1, the local derivative of b1 with respect to these inputs is zero here but one on these guys. In the chain rule we have the local derivative times the derivative of b1, and because the local derivative is one on these three elements, the local derivative multiplying the derivative of b1 just gives the derivative of b1. So you can look at it as a router: an addition is a router of gradient. Whatever gradient comes from above just gets routed equally to all the elements that participate in that addition. So in this case the derivative of b1 will just flow equally to the derivatives of a11, a12 and a13.
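To make the router picture concrete, here is a sketch of the 3 by 3 toy example using PyTorch autograd; the values of a and the upstream gradient db are arbitrary choices for illustration.

```python
import torch

# Toy 3x3 matrix A; b holds its row sums as a 3x1 column (b1, b2, b3).
a = torch.tensor([[0.0, 1.0, 2.0],
                  [3.0, 4.0, 5.0],
                  [6.0, 7.0, 8.0]], dtype=torch.float64, requires_grad=True)
b = a.sum(1, keepdim=True)

# Made-up upstream gradient dL/db, one entry per row.
db = torch.tensor([[1.0], [2.0], [3.0]], dtype=torch.float64)
b.backward(db)

# b_i depends only on row i of a, with local derivative 1 for each element,
# so every element of row i receives exactly dL/db_i: the sum routes the gradient.
print(a.grad)
```

Every row of a.grad repeats the matching entry of db, which is the horizontal routing described above.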
31:05
So if we have the derivatives of all the elements of B in this column tensor, which is dcounts_sum that we've calculated just now, we basically see that what that amounts to is all of these now flowing to all the elements of A, and they're doing that horizontally. So basically what we want is to take dcounts_sum, of size 32 by 1, and replicate it 27 times horizontally to create a 32 by 27 array.
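A sketch comparing ways to do this replication, assuming random stand-ins for dcounts_sum and counts. The repeat and expand variants are added here only for comparison; the broadcasting form with torch.ones_like is the one discussed in the lecture.

```python
import torch

torch.manual_seed(0)
# Hypothetical column gradient, shaped like counts_sum in the lecture (32 x 1).
dcounts_sum = torch.randn(32, 1)
counts = torch.rand(32, 27)

# Three equivalent ways to copy the column 27 times horizontally:
tiled = dcounts_sum.repeat(1, 27)                  # materializes a new 32 x 27 tensor
expanded = dcounts_sum.expand(-1, 27)              # a 32 x 27 view, no copy
broadcast = torch.ones_like(counts) * dcounts_sum  # broadcasting does the replication

print(tiled.shape, torch.equal(tiled, expanded), torch.equal(tiled, broadcast))
```

expand avoids allocating a copy, while repeat and the multiplication allocate a new tensor; all three hold the same 32 by 27 values.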
31:33
There are many ways to implement this operation. You could of course just replicate the tensor, but I think one clean way is that dcounts is simply torch.ones_like(counts), so just a two-dimensional array of ones in the shape of counts, 32 by 27, times dcounts_sum. This way we're letting broadcasting implement the replication; you can look at it that way. But then we also have to be careful, because dcounts was already calculated: we calculated it earlier here, and that was just the first branch. We're now finishing the second branch, so we need to make sure that these gradients add, so plus-equals. And then here, let's uncomment the comparison and make sure, crossing fingers, that we have the correct result. So PyTorch agrees with us on this gradient as well.
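A sketch of the two branches into counts accumulating, assuming small random counts and a hypothetical scalar loss, (probs * w).sum(), so that autograd has something to differentiate. Writing = instead of += on the last manual line would drop the first branch and the comparison would fail.

```python
import torch

torch.manual_seed(0)
n, vocab = 4, 5  # small stand-ins for the lecture's 32 x 27
counts = (torch.rand(n, vocab, dtype=torch.float64) + 0.1).requires_grad_()
counts_sum = counts.sum(1, keepdim=True)
counts_sum_inv = counts_sum**-1
probs = counts * counts_sum_inv

# Hypothetical scalar loss; w plays the role of dprobs.
w = torch.randn(n, vocab, dtype=torch.float64)
loss = (probs * w).sum()
loss.backward()

with torch.no_grad():
    dprobs = w
    # Branch 1: counts feeds probs directly.
    dcounts = counts_sum_inv * dprobs
    dcounts_sum_inv = (counts * dprobs).sum(1, keepdim=True)
    dcounts_sum = (-counts_sum**-2) * dcounts_sum_inv
    # Branch 2: counts also feeds counts_sum, so accumulate instead of overwriting.
    dcounts += torch.ones_like(counts) * dcounts_sum

print(torch.allclose(dcounts, counts.grad))
```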
32:29
Okay, hopefully we're getting the hang of this now. counts is an element-wise exp of norm_logits, so now we want dnorm_logits. Because it's an element-wise operation, everything is very simple. What is the local derivative of e to the x? It's famously just e to the x, so that is the local derivative. We already calculated it and it's inside counts, so we may as well just reuse counts. That is the local derivative, times dcounts. Funny as that looks, counts times dcounts is the derivative on norm_logits. Now let's erase this and verify, and it looks good. So that's dnorm_logits.
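A minimal sketch of reusing counts as the local derivative, with a random norm_logits and a made-up upstream gradient dcounts standing in for the real ones.

```python
import torch

torch.manual_seed(0)
norm_logits = torch.randn(4, 5, dtype=torch.float64, requires_grad=True)
counts = norm_logits.exp()

# Made-up upstream gradient dL/dcounts.
dcounts = torch.randn(4, 5, dtype=torch.float64)
counts.backward(dcounts)

# d/dx exp(x) = exp(x), which is already stored in counts, so reuse it.
dnorm_logits = counts.detach() * dcounts
print(torch.allclose(dnorm_logits, norm_logits.grad))
```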
33:17
Okay, so we are here on this line now. We have dnorm_logits, and we're trying to calculate dlogits and dlogit_maxes, so we're backpropagating through this line. Now we have to be careful here, because the shapes again are not the same, and so there's an implicit broadcasting happening here. norm_logits has shape 32 by 27, and logits does as well, but logit_maxes is only 32 by 1, so there's a broadcast in the minus.
33:45
Here I tried to write out a toy example again. We basically have c = a - b, and we see that because of the shapes, these are 3 by 3 but this one is just a column. So for every element of c, we have to look at how it came to be, and every element of c is just the corresponding element of a minus the associated element of b. So it's very clear now that the derivative of every one of these c's with respect to its inputs is one for the corresponding a and negative one for the corresponding b. Therefore the derivatives on c will flow equally to the corresponding a's and also to the corresponding b's, but in addition to that the b's are broadcast, so we'll have to do the additional sum just like we did before. And of course the derivatives for the b's will undergo a minus, because the local derivative here is negative one: dc32 by db3 is negative one.
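A sketch of this toy example, c = a - b with b broadcast as a column, using random tensors and an arbitrary upstream gradient dc, checked against autograd.

```python
import torch

torch.manual_seed(0)
a = torch.randn(3, 3, dtype=torch.float64, requires_grad=True)
b = torch.randn(3, 1, dtype=torch.float64, requires_grad=True)  # a column, broadcast across
c = a - b  # c_ij = a_ij - b_i

dc = torch.randn(3, 3, dtype=torch.float64)  # made-up upstream gradient
c.backward(dc)

da = dc.clone()                  # local derivative +1: copy straight through
db = (-dc).sum(1, keepdim=True)  # local derivative -1, then sum over the broadcast columns

print(torch.allclose(da, a.grad), torch.allclose(db, b.grad))
```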
34:52
So let's just implement that. dlogits will be exactly a copy of the derivative on norm_logits, so dlogits equals dnorm_logits, and I'll do a .clone() for safety, so we're just making a copy. Then dlogit_maxes will be the negative of dnorm_logits, because of the negative sign. And then we have to be careful, because logit_maxes is a column. Just like we saw before, because we keep replicating the same elements across all the columns, in the backward pass these are all like separate branches of use of that one variable, so we have to do a sum along dimension 1 with keepdim=True so that we don't destroy this dimension. Then dlogit_maxes will be the same shape. Now we have to be careful, because this dlogits is not the final dlogits: not only do we get gradient signal into logits through here, but logit_maxes is a function of logits, and that's a second branch into logits. So this is not yet our final derivative for logits; we will come back later for the second branch. For now, dlogit_maxes is the final derivative, so let me uncomment this cmp here and just run this. And for dlogit_maxes, PyTorch agrees with us. So that was the derivative through this line.
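A sketch of this line inside a small, lecture-style forward pass, assuming random logits and hypothetical targets Yb; dnorm_logits is taken from autograd here rather than derived by hand. Both sides of the comparison come out nearly zero, which the next part of the lecture explains.

```python
import torch

torch.manual_seed(0)
n, vocab = 4, 5  # small stand-ins for the lecture's 32 x 27
logits = torch.randn(n, vocab, dtype=torch.float64, requires_grad=True)
Yb = torch.randint(0, vocab, (n,))  # hypothetical targets

# Forward pass broken into atoms, reusing the lecture's variable names.
logit_maxes = logits.max(1, keepdim=True).values
norm_logits = logits - logit_maxes  # (n, vocab) - (n, 1): broadcast in the minus
counts = norm_logits.exp()
counts_sum_inv = counts.sum(1, keepdim=True)**-1
probs = counts * counts_sum_inv
loss = -probs.log()[range(n), Yb].mean()

logit_maxes.retain_grad()
norm_logits.retain_grad()
loss.backward()

# Take dnorm_logits as given (here from autograd) and do this line by hand.
dnorm_logits = norm_logits.grad
dlogits = dnorm_logits.clone()  # first branch only: not the final dlogits yet
dlogit_maxes = (-dnorm_logits).sum(1, keepdim=True)  # minus sign, then sum over columns

print(dlogit_maxes.shape, torch.allclose(dlogit_maxes, logit_maxes.grad))
```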
36:22
Now, before we move on, I want to pause here briefly and look at these logit_maxes, and especially their gradients. We talked in the previous lecture about how the only reason we're doing this is the numerical stability of the softmax that we are implementing here. We talked about how if you take the logits for any one of these examples, so one row of this logits tensor, and add or subtract any value equally to all the elements, then the values of the probs will be unchanged; you're not changing the softmax. The only thing this is doing is making sure that exp doesn't overflow, and the reason we're using a max is that we are then guaranteed that the highest number in each row of logits is zero, so this will be safe. And that has repercussions: if changing logit_maxes does not change the probs, and therefore does not change the loss, then the gradient on logit_maxes should be zero, because saying those two things is the same. So indeed we hope these are very, very small numbers; indeed we hope this is zero. Now, because of floating-point wonkiness, this doesn't come out exactly zero (only in some of the rows does it), but we get extremely small values like 1e-9 or 1e-10. So this is telling us that the values of logit_maxes are not impacting the loss, which is as it should be.
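A short illustration of the two claims above, using a plain row-wise softmax written out by hand (the helper softmax_rows is hypothetical, defined only for this sketch): shifting a row by any constant leaves the probabilities unchanged, and without the max shift, large logits overflow exp and produce nan.

```python
import torch

def softmax_rows(x):
    # Plain row-wise softmax with no stabilization.
    e = x.exp()
    return e / e.sum(1, keepdim=True)

x = torch.tensor([[1.0, 2.0, 3.0]], dtype=torch.float64)
# Shifting a row by any constant leaves the probabilities unchanged...
p = softmax_rows(x)
p_shift = softmax_rows(x - 5.0)
p_max = softmax_rows(x - x.max(1, keepdim=True).values)  # row max becomes 0

# ...but without the shift, large logits overflow exp() to inf, and inf / inf is nan.
big = torch.tensor([[1000.0, 1001.0, 1002.0]], dtype=torch.float64)
unsafe = softmax_rows(big)
safe = softmax_rows(big - big.max(1, keepdim=True).values)

print(torch.allclose(p, p_shift), torch.allclose(p, p_max))
print(torch.isnan(unsafe).any().item(), torch.allclose(safe, p))
```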
37:44
It feels kind of weird to backpropagate through this branch, honestly, because if you have an implementation of something like F.cross_entropy in PyTorch, and you bundle together all these elements and don't do the backpropagation piece by piece, then you would probably assume that the derivative through here is exactly zero, so you would skip this branch, because it's only done for numerical stability. But it's interesting to see that even if you break everything up into the full atoms and still do the computation as you'd like with respect to numerical stability, the correct thing happens, and you still get very, very small gradients here, reflecting the fact that the values of these do not matter with respect to the final loss.
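One way to see this, as a sketch assuming random logits and hypothetical targets: compute the gradient on logits with the max branch kept, with it cut off via detach() (which is what skipping the branch amounts to), and with F.cross_entropy as a reference. The helper manual_loss is hypothetical, written for this sketch.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n, vocab = 4, 5
logits = torch.randn(n, vocab, dtype=torch.float64, requires_grad=True)
Yb = torch.randint(0, vocab, (n,))  # hypothetical targets

def manual_loss(logits, detach_max):
    # Lecture-style atomic cross-entropy; optionally treat the row max as a constant.
    logit_maxes = logits.max(1, keepdim=True).values
    if detach_max:
        logit_maxes = logit_maxes.detach()
    norm_logits = logits - logit_maxes
    probs = norm_logits.exp() / norm_logits.exp().sum(1, keepdim=True)
    return -probs.log()[range(n), Yb].mean()

g_full = torch.autograd.grad(manual_loss(logits, detach_max=False), logits)[0]
g_skip = torch.autograd.grad(manual_loss(logits, detach_max=True), logits)[0]
g_ref = torch.autograd.grad(F.cross_entropy(logits, Yb), logits)[0]

print((g_full - g_skip).abs().max().item())  # floating-point noise level
print(torch.allclose(g_full, g_ref), torch.allclose(g_skip, g_ref))
```

The first print shows how much the max branch contributes, which is only rounding noise, so keeping or skipping it gives the same gradient as F.cross_entropy.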
38:29
Okay, so let's now continue backpropagation through this line. We've just calculated dlogit_maxes, and now we want to backprop into logits through this second branch. Here, of course, we took logits, took the max along all the rows, and then looked at its values. The way this works in PyTorch is that max returns both the values and the indices at which those maximum values occur. In the forward pass we only used the values, because that's all we needed, but in the backward pass it's extremely useful to know where those maximum values occurred, and we have the indices at which they occurred. That of course helps us do the backpropagation. What should the backward pass be here? We have the logits tensor, which is 32 by 27, and in each row we find the maximum value, and that value gets plucked out into logit_maxes. So intuitively, the derivative flowing through here should be the local derivative, which is 1 for the appropriate entry that was plucked out, times the global derivative of logit_maxes. So really, if you think it through, we need to take dlogit_maxes and scatter it to the correct positions in logits from where the maximum values came.
39:53
came
39:53
And so I came up with one line of code that does that. Let me just erase a bunch of stuff here. You could do it very similarly to what we've done before, where we create a tensor of zeros and then populate the correct elements: we use the indices here and set those elements to 1. But you can also use one-hot: F.one_hot, and then I'm taking logits.max over the first dimension, .indices, and I'm telling PyTorch that the dimension of every one of these tensors should be 27.
40:33
OK, if I plot this with plt.imshow, it's really just an array of where the maxes came from in each row: that element is 1 and all the other elements are 0, so it's a one-hot vector in each row, and these indices populate a single 1 in the proper place.
40:54
Then what I'm doing here is multiplying by dlogit_maxes, and keep in mind that this is a column of 32 by 1. So when I multiply by dlogit_maxes, it broadcasts: that column gets replicated, and the element-wise multiply ensures that each of these values gets routed to whichever one of these bits is turned on. That's another way to implement this kind of operation, and both of them can be used; I just thought I would show an equivalent way to do it. I'm using plus-equals because we already calculated dlogits here, and this is now the second branch.
41:35
So let's check logits and make sure that this is correct, and we see that we have exactly the correct answer.
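The two equivalent routings described above (scatter into zeros at the argmax indices, or a one-hot mask times the broadcast column) can be sketched in plain Python, with nested lists standing in for PyTorch tensors. This is a minimal illustration under that assumption, not the lecture's code: the helper names are hypothetical, and the one-hot version only mimics what multiplying a mask like F.one_hot(logits.max(1).indices, num_classes=27) by the 32 by 1 column dlogit_maxes does.

```python
def argmax(row):
    """Index of the largest entry in a row (the first one on ties)."""
    return max(range(len(row)), key=lambda j: row[j])


def max_backward_scatter(logits, dlogit_maxes):
    """Route each row's upstream gradient to the position its max came from.

    logits: n x C list of lists; dlogit_maxes: n x 1 column (list of 1-element lists).
    """
    dlogits = [[0.0] * len(row) for row in logits]
    for i, row in enumerate(logits):
        # The local derivative of the max is 1 at the plucked-out entry, 0 elsewhere.
        dlogits[i][argmax(row)] += dlogit_maxes[i][0]
    return dlogits


def one_hot(indices, num_classes):
    """Plain-Python analogue of a one-hot encoding: a single 1.0 per row."""
    return [[1.0 if j == k else 0.0 for j in range(num_classes)] for k in indices]


def max_backward_one_hot(logits, dlogit_maxes):
    """Same gradient via a one-hot mask times the n x 1 upstream column."""
    mask = one_hot([argmax(row) for row in logits], num_classes=len(logits[0]))
    # Scaling every entry of mask row i by dlogit_maxes[i][0] mimics broadcasting
    # the n x 1 column across all C columns before the element-wise multiply.
    return [[m * g[0] for m in mask_row] for mask_row, g in zip(mask, dlogit_maxes)]


logits = [[0.1, 2.0, -1.0],
          [3.0, 0.5, 0.2]]
dlogit_maxes = [[0.7], [-0.3]]

print(max_backward_scatter(logits, dlogit_maxes))  # [[0.0, 0.7, 0.0], [-0.3, 0.0, 0.0]]
print(max_backward_scatter(logits, dlogit_maxes) == max_backward_one_hot(logits, dlogit_maxes))  # True
```

In the lecture the result is accumulated with += into the dlogits already computed along the other branch; the sketch returns only this branch's contribution.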
41:44
Next up, we want to continue with logits here, which is the outcome of a matrix multiplication and a bias offset in this linear layer.
41:53
I've printed out the shapes of all these intermediate tensors. We see that logits is of course 32 by 27, as we've just seen. Then h here is 32 by 64, so these are 64-dimensional hidden states, and this W2 matrix projects those 64-dimensional vectors into 27 dimensions. Then there's a 27-dimensional offset, b2, which is a one-dimensional vector.
42:18
Now we should note that this plus here actually broadcasts, because h multiplied by W2 gives us a 32 by 27, while b2 is a 27-dimensional vector.
42:32
Under the rules of broadcasting, this one-dimensional bias vector of 27 gets aligned with a padded dimension of 1 on the left, so it basically becomes a 1 by 27 row vector, and then it gets replicated vertically 32 times to make it 32 by 27. Then there's an element-wise addition.
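The broadcasting rule described above (right-align the shapes, pad the shorter one with 1s on the left, then stretch any size-1 dimension) can be sketched as a small shape calculator. This is a simplified, hypothetical re-implementation for illustration, not PyTorch's own code; it covers only the shape logic plus an explicit row-bias add.

```python
def broadcast_shape(shape_a, shape_b):
    """Result shape of an element-wise op under NumPy/PyTorch-style broadcasting."""
    ndim = max(len(shape_a), len(shape_b))
    # Align on the right by padding the shorter shape with 1s on the left.
    a = (1,) * (ndim - len(shape_a)) + tuple(shape_a)
    b = (1,) * (ndim - len(shape_b)) + tuple(shape_b)
    out = []
    for da, db in zip(a, b):
        if da != db and da != 1 and db != 1:
            raise ValueError(f"shapes {shape_a} and {shape_b} do not broadcast")
        # A size-1 dimension is stretched (replicated) to match the other one.
        out.append(db if da == 1 else da)
    return tuple(out)


def add_row_bias(matrix, bias):
    """Explicit version of matrix + bias for a 1-D bias: the same row is added to every row."""
    return [[x + b for x, b in zip(row, bias)] for row in matrix]


print(broadcast_shape((32, 27), (27,)))   # (32, 27): b2 acts as 1 x 27, repeated 32 times
print(broadcast_shape((32, 27), (32, 1)))  # (32, 27): a 32 x 1 column repeated 27 times
```

The second call is the same rule that lets the 32 by 1 column dlogit_maxes spread across all 27 columns in the max backward pass.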
42:52
Now the question is: how do we backpropagate from logits to the hidden states h, the weight matrix W2, and the bias b2?
43:02
You might think that we need to go to some matrix calculus and look up the derivative of a matrix multiplication, but actually you don't have to do any of that. You can go back to first principles and derive this yourself on a piece of paper. Specifically, what I like to do, and what I find works well for me, is to find a specific small example that you then fully write out. In the process of analyzing how that individual small example works, you will understand the broader pattern, and you'll be able to generalize and write out the full general formula for how the derivatives flow in an expression like this. So let's try that out.
43:39
Pardon the low-budget production here, but what I've done is write it out on a piece of paper. Really, what we are interested in is this: we have A multiplied by B, plus C, and that creates D. We have the derivative of the loss with respect to D, and we'd like to know the derivative of the loss with respect to A, B, and C.
43:57
These are little two-dimensional examples: a matrix multiplication of a 2 by 2 times a 2 by 2, plus a vector of just two elements, c1 and c2, gives me a 2 by 2.
44:10
Notice that I have a bias vector here called c, with elements c1 and c2, but as I described earlier, that bias vector becomes a row vector under broadcasting and is replicated vertically. That's what's happening here as well: c1, c2 is replicated vertically, and we have two rows of c1, c2 as a result.
44:33
So when I say "write it out," I just mean this: break up the matrix multiplication into what is actually going on under the hood. Because of how matrix multiplication works, d11 is the result of a dot product between the first row of A and the first column of B, so d11 = a11 b11 + a12 b21 + c1, and so on for all the other elements of D. Once you actually write it out, it becomes obvious that this is just a bunch of multiplies and adds, and we know from micrograd how to differentiate multiplies and adds. So this is not scary anymore: it's no longer an opaque matrix multiplication, it's just tedious, unfortunately, but completely tractable. We have dL/dD for all of these, and we want dL by all these little other variables. So how do we achieve that, and how do we actually get the gradients?
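Written out in full, the 2 by 2 example D = AB + c (with the bias row c = (c1, c2) replicated down both rows) is just four scalar expressions:

```latex
\begin{aligned}
d_{11} &= a_{11} b_{11} + a_{12} b_{21} + c_1, & d_{12} &= a_{11} b_{12} + a_{12} b_{22} + c_2, \\
d_{21} &= a_{21} b_{11} + a_{22} b_{21} + c_1, & d_{22} &= a_{21} b_{12} + a_{22} b_{22} + c_2.
\end{aligned}
```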
45:29
Okay, so the low-budget production continues here. Let's, for example, derive the derivative of the loss with respect to a11. We see that a11 occurs twice in our simple expression, right here and right here, and it influences d11 and d12. So what is dL/da11? Well, it's dL/dd11 times the local derivative of d11 with respect to a11, which in this case is just b11, because that's what's multiplying a11 here. Likewise, the local derivative of d12 with respect to a11 is just b12, so by the chain rule b12 multiplies dL/dd12. And because a11 is used to produce both d11 and d12, we need to add up the contributions of those two chains running in parallel. That's why we get a plus, just adding up those two contributions, and that gives us dL/da11.
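The a11 result as an equation, together with what the same chain-rule bookkeeping gives for the other three entries of A (each a_ij feeds the two entries in row i of D, scaled by the matching entries of B):

```latex
\begin{aligned}
\frac{\partial L}{\partial a_{11}} &= \frac{\partial L}{\partial d_{11}}\, b_{11} + \frac{\partial L}{\partial d_{12}}\, b_{12}, &
\frac{\partial L}{\partial a_{12}} &= \frac{\partial L}{\partial d_{11}}\, b_{21} + \frac{\partial L}{\partial d_{12}}\, b_{22}, \\
\frac{\partial L}{\partial a_{21}} &= \frac{\partial L}{\partial d_{21}}\, b_{11} + \frac{\partial L}{\partial d_{22}}\, b_{12}, &
\frac{\partial L}{\partial a_{22}} &= \frac{\partial L}{\partial d_{21}}\, b_{21} + \frac{\partial L}{\partial d_{22}}\, b_{22}.
\end{aligned}
```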
46:29
We can do the exact same analysis for all the other elements of A, and when you simply write it out, it's just super simple taking of gradients on expressions like these. You find that for the matrix dL/dA that we're after, if we arrange all of these derivatives in the same shape as A (A is just a 2 by 2 matrix), dL/dA is a tensor of that same shape holding the derivatives dL/da11 and so on.
47:05
And we see that we can actually express what we've written out here as a matrix multiplication. It just so happens that all of these formulas we derived by taking gradients can be expressed as a matrix multiply, and in particular it is the matrix multiplication of these two matrices: dL/dD matrix-multiplying B, but actually B transpose. You see that b21 and b12 have changed places, whereas before we had, of course, b11, b12, b21, b22 in their usual positions. So this other matrix B is transposed.
47:47
Long story short, just by doing very simple reasoning, breaking up the expression in the case of a very simple example, we find that dL/dA is simply equal to dL/dD matrix-multiplied with B transpose.
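One way to convince yourself of dL/dA = dL/dD @ B^T is a numerical check on a 2 by 2 example, sketched here in plain Python. The scalar loss L = sum(G * D) is an assumption chosen only so that dL/dD equals a known matrix G; the numbers and helper names are made up for illustration.

```python
def matmul(X, Y):
    """Plain-Python (n x k) @ (k x m) -> (n x m)."""
    return [[sum(X[i][p] * Y[p][j] for p in range(len(Y))) for j in range(len(Y[0]))]
            for i in range(len(X))]


def transpose(X):
    return [list(col) for col in zip(*X)]


def loss(A, B, c, G):
    """Hypothetical scalar loss L = sum(G * D) with D = A @ B + c, so dL/dD == G exactly."""
    D = [[v + cj for v, cj in zip(row, c)] for row in matmul(A, B)]
    return sum(g * d for g_row, d_row in zip(G, D) for g, d in zip(g_row, d_row))


def nudged(M, i, j, delta):
    """Copy of matrix M with entry (i, j) shifted by delta."""
    out = [row[:] for row in M]
    out[i][j] += delta
    return out


def numerical_dA(A, B, c, G, h=1e-6):
    """Central-difference estimate of dL/dA, one entry at a time."""
    return [[(loss(nudged(A, i, j, h), B, c, G) - loss(nudged(A, i, j, -h), B, c, G)) / (2 * h)
             for j in range(len(A[0]))]
            for i in range(len(A))]


A = [[1.0, 2.0], [3.0, 4.0]]
B = [[0.5, -1.0], [2.0, 0.25]]
c = [0.1, -0.2]
G = [[1.0, -2.0], [0.5, 3.0]]  # plays the role of dL/dD

dA = matmul(G, transpose(B))  # the formula: dL/dD @ B^T
print(dA)                      # [[2.5, 1.5], [-2.75, 1.75]]
print(numerical_dA(A, B, c, G))  # approximately the same values
```

Because L is linear in A, the central differences agree with the formula up to floating-point rounding.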
48:05
So that is what we have so far. Now we also want the derivative with respect to B and c. For B, I'm not actually doing the full derivation, because honestly it's not deep, it's just annoying and exhausting, and you can do this analysis yourself. If you take these expressions and differentiate with respect to B instead of A, you will find that dL/dB is also a matrix multiplication. In this case, you have to take the matrix A, transpose it, and matrix-multiply that with dL/dD, and that's what gives you dL/dB.
48:42
Then for the offsets c1 and c2, if you again just differentiate with respect to c1, you will find an expression like this, and for c2 an expression like this. Basically, because c1 and c2 are just offsetting these expressions, you'll find that to get dL/dc you take the dL/dD matrix of the derivatives of D and sum down each column, that is, over the rows. That gives you the derivatives for c.
49:11
have to sum across the columns and that gives you the derivatives for C
49:13
gives you the derivatives for C
49:13
gives you the derivatives for C so long story short
49:15
so long story short
49:15
so long story short the backward Paths of a matrix multiply
49:18
the backward Paths of a matrix multiply
49:18
the backward Paths of a matrix multiply is a matrix multiply
49:20
is a matrix multiply
49:20
is a matrix multiply and instead of just like we had D equals
49:22
and instead of just like we had D equals
49:22
and instead of just like we had D equals a times B plus C in the scalar case uh
49:25
a times B plus C in the scalar case uh
49:25
a times B plus C in the scalar case uh we sort of like arrive at something very
49:27
we sort of like arrive at something very
49:27
we sort of like arrive at something very very similar but now uh with a matrix
49:29
very similar but now uh with a matrix
49:29
very similar but now uh with a matrix multiplication instead of a scalar
49:31
multiplication instead of a scalar
49:31
multiplication instead of a scalar multiplication
49:32
multiplication
49:32
multiplication so the derivative of D with respect to a
49:36
so the derivative of D with respect to a
49:36
so the derivative of D with respect to a is
49:37
is
49:37
is DL by DD Matrix multiplied B trespose
49:41
DL by DD Matrix multiplied B trespose
49:41
DL by DD Matrix multiplied B trespose and here it's a transpose multiply deal
49:44
and here it's a transpose multiply deal
49:44
and here it's a transpose multiply deal by DD but in both cases it's a matrix
49:46
by DD but in both cases it's a matrix
49:46
by DD but in both cases it's a matrix multiplication with the derivative and
49:49
multiplication with the derivative and
49:49
multiplication with the derivative and the other term in the multiplication
49:53
the other term in the multiplication
49:53
the other term in the multiplication and for C it is a sum
49:55
and for C it is a sum
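For reference, the three results stated for general shapes (an extrapolation of the 2 by 2 case, consistent with the formulas above). In the lecture's linear layer, A plays the role of h (32 by 64), B of W2 (64 by 27), and c of b2 (27).

```latex
\begin{gathered}
D = AB + c, \quad A \in \mathbb{R}^{n \times k},\ B \in \mathbb{R}^{k \times m},\ c \in \mathbb{R}^{m}, \\
\frac{\partial L}{\partial A} = \frac{\partial L}{\partial D}\, B^{\top} \in \mathbb{R}^{n \times k}, \qquad
\frac{\partial L}{\partial B} = A^{\top} \frac{\partial L}{\partial D} \in \mathbb{R}^{k \times m}, \qquad
\frac{\partial L}{\partial c_j} = \sum_{i=1}^{n} \frac{\partial L}{\partial D_{ij}}.
\end{gathered}
```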
49:55
and for C it is a sum now I'll tell you a secret I can never
49:58
now I'll tell you a secret I can never
49:58
now I'll tell you a secret I can never remember the formulas that we just
50:00
remember the formulas that we just
50:00
remember the formulas that we just arrived for back proper gain information
50:01
arrived for back proper gain information
50:01
arrived for back proper gain information multiplication and I can back propagate
50:03
multiplication and I can back propagate
50:03
multiplication and I can back propagate through these Expressions just fine and
50:05
through these Expressions just fine and
50:05
50:07
The reason this works is because the dimensions have to work out. So let me give you an example. Say I want to create dh. Then what should dh be? Number one, I have to know that the shape of dh must be the same as the shape of h, and the shape of h is 32 by 64. And then the other piece of information I know is that dh must be some kind of matrix multiplication of dlogits with W2. dlogits is 32 by 27 and W2 is 64 by 27. There is only a single way to make the shapes work out in this case, and it is indeed the correct result.
50:48
In particular, dh here needs to be 32 by 64. The only way to achieve that is to take dlogits and matrix multiply it with... you see how I have to take W2, but I have to transpose it to make the dimensions work out. So W2 transpose. It's the only way to matrix multiply those two pieces so that the shapes work out, and that turns out to be the correct formula.
51:11
So if we come here, we want dh, which is dA, and we see that dA is dL/dD matrix multiplied with B transpose. So that's dlogits matrix multiplied, and B is W2, so W2 transpose, which is exactly what we have here. So there's no need to remember these formulas.
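The shape argument above can be made mechanical. The sketch below is illustrative and is not code from the lecture. It uses NumPy instead of PyTorch and assumes the lecture's shapes: 32 examples, 64 hidden units and 27 characters. It tries both operand orders, with and without a transpose on each operand. It then keeps only the products that are defined and that come out with the shape of h.

```python
import itertools

import numpy as np

# Shapes assumed from the lecture: batch 32, hidden 64, vocabulary 27.
operands = {
    "dlogits": np.zeros((32, 27)),  # upstream gradient dL/dlogits
    "W2": np.zeros((64, 27)),       # second-layer weights
}
target_shape = (32, 64)             # dh must have the same shape as h

valid = []
for (name_a, a), (name_b, b) in itertools.permutations(operands.items(), 2):
    for trans_a, trans_b in itertools.product([False, True], repeat=2):
        A = a.T if trans_a else a
        B = b.T if trans_b else b
        # A @ B is only defined when the inner dimensions agree.
        if A.shape[1] != B.shape[0]:
            continue
        if (A.shape[0], B.shape[1]) == target_shape:
            valid.append(f"{name_a}{'.T' if trans_a else ''} @ {name_b}{'.T' if trans_b else ''}")

print(valid)  # ['dlogits @ W2.T']
```

You can run the same search with operands h (32 by 64) and dlogits and a target shape of (64, 27). It likewise leaves a single candidate: h.T @ dlogits.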
51:30
Similarly, now if I want dW2, well, I know that it must be a matrix multiplication of dlogits and h. Maybe there's a transpose in there as well, and I don't know which way it is. So I have to come to W2, and I see that its shape is 64 by 27, and that has to come from some matrix multiplication of these two. To get a 64 by 27, I need to take h, transpose it, and then matrix multiply it. That will become 64 by 32, and then I need to make sure to multiply it with the 32 by 27 dlogits, and that's going to give me a 64 by 27. So I need to multiply this with dlogits, just like that. That's the only way to make the dimensions work out using matrix multiplication. And if we come here, we see that that's exactly what's here: A transpose, where A for us is h, multiplied with dlogits. So that's dW2.
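Here is a quick numerical check of the dW2 formula, offered as a sketch. NumPy stands in for PyTorch and the data are random. The loss is a hypothetical toy, L = sum(G * logits). It is chosen so that its gradient with respect to logits is exactly the matrix G, which plays the role of dlogits. Because this loss is linear in W2, central differences should agree with the analytic gradient up to rounding error.

```python
import numpy as np

rng = np.random.default_rng(0)
h = rng.standard_normal((32, 64))   # hidden activations (32 examples, 64 units)
W2 = rng.standard_normal((64, 27))
b2 = rng.standard_normal(27)
G = rng.standard_normal((32, 27))   # stand-in for the upstream gradient dlogits


def loss(W):
    # Toy scalar loss chosen so that dL/dlogits is exactly G.
    logits = h @ W + b2
    return np.sum(G * logits)


dW2 = h.T @ G  # (64, 32) @ (32, 27) -> (64, 27), same shape as W2

# Central finite differences, one entry of W2 at a time.
eps = 1e-4
dW2_num = np.zeros_like(W2)
for i in range(W2.shape[0]):
    for j in range(W2.shape[1]):
        W_plus, W_minus = W2.copy(), W2.copy()
        W_plus[i, j] += eps
        W_minus[i, j] -= eps
        dW2_num[i, j] = (loss(W_plus) - loss(W_minus)) / (2 * eps)

print(dW2.shape, np.abs(dW2 - dW2_num).max())
```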
52:30
And then db2 is just the vertical sum. Actually, in the same way, there's only one way to make the shapes work out. I don't have to remember that it's a vertical sum along the zeroth axis, because that's the only way that this makes sense: b2's shape is 27, and dlogits here is 32 by 27. So, knowing that it's just a sum over dlogits in some direction, that direction must be zero, because I need to eliminate this dimension. So it's this.
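The same toy setup can check db2: NumPy, random data, and the hypothetical loss L = sum(G * logits), with G standing in for dlogits. In the forward pass b2 has shape (27,) and is broadcast down the 32 rows. Every row's gradient therefore adds into it, and the sketch below takes the column sum of G.

```python
import numpy as np

rng = np.random.default_rng(0)
h = rng.standard_normal((32, 64))
W2 = rng.standard_normal((64, 27))
b2 = rng.standard_normal(27)        # shape (27,): broadcast down all 32 rows
G = rng.standard_normal((32, 27))   # stand-in for the upstream gradient dlogits


def loss(b):
    # Toy scalar loss chosen so that dL/dlogits is exactly G.
    return np.sum(G * (h @ W2 + b))


# b2 is reused by every row, so the per-row gradients add up:
# sum over axis 0, the only reduction that leaves shape (27,).
db2 = G.sum(axis=0)

eps = 1e-4
basis = np.eye(b2.size)
db2_num = np.array([(loss(b2 + eps * e) - loss(b2 - eps * e)) / (2 * eps) for e in basis])

print(db2.shape, np.abs(db2 - db2_num).max())
```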
53:06
So this is kind of the hacky way. Let me copy-paste and delete that, and let me swing over here. This is our backward pass for the linear layer, hopefully. So now let's uncomment these three lines, where we're checking that we got all three derivatives correct, and run. And we see that h, W2 and b2 are all exactly correct, so we backpropagated through a linear layer.
53:39
Now next up, we have the derivative for h already, and we need to backpropagate through tanh into hpreact. So we want to derive dhpreact, and here we have to backpropagate through a tanh. We've already done this in micrograd, and we remember that tanh has a very simple backward formula. Now, unfortunately, if I just put d/dx of tanh(x) into Wolfram Alpha, it lets us down: it tells us that it's the hyperbolic secant function squared of x. That's not exactly helpful, but luckily Google image search does not let us down, and it gives us the simpler formula. In particular, if you have that a is equal to tanh(z), then da/dz, backpropagating through tanh, is just 1 minus a squared. Take note that a in 1 minus a squared is the output of the tanh, not the input z to the tanh, so da/dz is here formulated in terms of the output of that tanh. Also in Google image search, we have the full derivation if you want to actually take the definition of tanh and work through the math to figure out 1 minus tanh squared of z.
54:45
So 1 minus a squared is the local derivative. In our case, that is 1 minus the output of tanh squared, which here is h, so it's 1 minus h squared. That is the local derivative, and then we multiply by the chain rule, dh. So that is going to be our candidate implementation. If we come here and then uncomment this, let's hope for the best... and we have the right answer.
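Here is a minimal NumPy sketch of the tanh step. It assumes a random hpreact and a random stand-in for the upstream gradient dh. It checks that Wolfram Alpha's sech squared form and the 1 minus tanh squared form give the same numbers. It then compares the result against a finite difference. Because tanh acts element-wise, perturbing the whole array at once yields every entry's derivative in one step.

```python
import numpy as np

rng = np.random.default_rng(0)
hpreact = rng.standard_normal((32, 64))
h = np.tanh(hpreact)
dh = rng.standard_normal((32, 64))  # stand-in for the upstream gradient dL/dh

# Local derivative written in terms of the tanh *output* h, times the chain rule.
dhpreact = (1.0 - h**2) * dh

# Wolfram Alpha's form, sech^2(x) = 1 / cosh^2(x), is the same quantity.
sech2 = 1.0 / np.cosh(hpreact) ** 2
print(np.abs((1.0 - h**2) - sech2).max())

# tanh acts element-wise, so one central difference on the whole array
# gives every entry's derivative at once.
eps = 1e-6
dhpreact_num = dh * (np.tanh(hpreact + eps) - np.tanh(hpreact - eps)) / (2 * eps)
print(np.abs(dhpreact - dhpreact_num).max())
```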
55:15
Okay, next up we have dhpreact, and we want to backpropagate into bngain, bnraw and bnbias. Here bngain and bnbias are the batchnorm parameters: inside the batchnorm, they take bnraw, which is exactly unit Gaussian, and then scale it and shift it. These are the parameters of the batchnorm. Now, here we have a multiplication, but it's worth noting that this multiply is very, very different from this matrix multiply here. Matrix multiplies are dot products between rows and columns of the matrices involved; this is an element-wise multiply, so things are quite a bit simpler.
55:49
Now, we do have to be careful with some of the broadcasting happening in this line of code, though. You see how bngain and bnbias are 1 by 64, but hpreact and bnraw are 32 by 64. So we have to be careful with that and make sure that all the shapes work out fine and that the broadcasting is correctly backpropagated.
56:10
So in particular, let's start with bngain. dbngain should be... and here this is again an element-wise multiply. Whenever we have a times b equals c, we saw that the local derivative is just the other one: if this is a, the local derivative is just b. So the local derivative is just bnraw, and then we multiply by the chain rule, dhpreact. So this is the candidate gradient.
56:38
Now again, we have to be careful, because bngain is of size 1 by 64, but this here would be 32 by 64. And so the correct thing to do in this case, of course, is to note that bngain here is a row vector of 64 numbers that gets replicated vertically in this operation. Therefore the correct thing to do is to sum, because it's being replicated: all the gradients in each of the rows that are now flowing backwards need to sum up into that same tensor, dbngain. So we have to sum across the zeroth dimension, all the examples basically, which is the direction in which this gets replicated. We also have to be careful because bngain is of shape 1 by 64, so in fact I need keepdim=True, otherwise I would just get 64.
57:34
Now, I don't actually really remember why I made bngain and bnbias 1 by 64, but the biases b1 and b2 I just made one-dimensional vectors; they're not two-dimensional tensors. So I can't recall exactly why I left the gain and the bias as two-dimensional, but it doesn't really matter as long as you are consistent and you're keeping it the same. So in this case we want to keep the dimension so that the tensor shapes work.
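The keepdim point can be seen from shapes alone. This NumPy sketch uses random data. It sums the per-example contributions for dbngain twice, once without and once with keeping the reduced dimension. NumPy spells PyTorch's keepdim=True as keepdims=True.

```python
import numpy as np

rng = np.random.default_rng(0)
bnraw = rng.standard_normal((32, 64))
dhpreact = rng.standard_normal((32, 64))  # stand-in upstream gradient
bngain = np.ones((1, 64))                 # kept two-dimensional, as in the lecture

prod = bnraw * dhpreact                    # (32, 64): one contribution per example
flat = prod.sum(axis=0)                    # (64,): no longer matches bngain's shape
dbngain = prod.sum(axis=0, keepdims=True)  # (1, 64): matches bngain

print(prod.shape, flat.shape, dbngain.shape)  # (32, 64) (64,) (1, 64)
```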
58:05
Next up we have bnraw. dbnraw will be bngain multiplying dhpreact; that's our chain rule. Now, what about the dimensions of this? We have to be careful, right? dhpreact is 32 by 64 and bngain is 1 by 64, so it will just get replicated to create this multiplication. That is the correct thing, because in the forward pass it also gets replicated in just the same way. So in fact we don't need the brackets here; we're done, and the shapes are already correct.
58:43
And finally, for the bias, it's very similar. This bias here is very, very similar to the bias we saw in the linear layer, and we see that the gradients from each preact will simply flow into the biases and add up, because these are just offsets. So basically we want this to be dhpreact, but it needs to sum along the right dimension. In this case, similar to the gain, we need to sum across the zeroth dimension, the examples, because of the way that the bias gets replicated vertically. We also want keepdim=True. So this will basically take this, sum it up, and give us a 1 by 64.
59:23
So this is the candidate implementation; it makes all the shapes work. Let me bring it down here, and then let me uncomment these three lines to check that we are getting the correct result for all three tensors. And indeed, we see that all of that got backpropagated correctly.
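Here is a sketch of the three gradients through hpreact = bngain * bnraw + bnbias, checked in NumPy against central finite differences. It assumes random data and a hypothetical toy loss, L = sum(G * hpreact), so that G plays the role of dhpreact. The numerical_grad helper is illustrative and is not the lecture's cmp function.

```python
import numpy as np

rng = np.random.default_rng(0)
bnraw = rng.standard_normal((32, 64))
bngain = rng.standard_normal((1, 64))
bnbias = rng.standard_normal((1, 64))
G = rng.standard_normal((32, 64))  # stand-in for the upstream gradient dhpreact


def loss():
    # Toy scalar loss whose gradient with respect to hpreact is exactly G.
    hpreact = bngain * bnraw + bnbias  # (1, 64) operands broadcast against (32, 64)
    return np.sum(G * hpreact)


dhpreact = G
dbngain = (bnraw * dhpreact).sum(0, keepdims=True)  # (1, 64): sum over the replicated rows
dbnraw = bngain * dhpreact                          # (32, 64): broadcasting already matches
dbnbias = dhpreact.sum(0, keepdims=True)            # (1, 64): the offsets just add up


def numerical_grad(param, eps=1e-4):
    # Central differences of loss() with respect to one array, perturbed in place.
    grad = np.zeros_like(param)
    for idx in np.ndindex(param.shape):
        old = param[idx]
        param[idx] = old + eps
        f_plus = loss()
        param[idx] = old - eps
        f_minus = loss()
        param[idx] = old  # restore the original value
        grad[idx] = (f_plus - f_minus) / (2 * eps)
    return grad


errors = {
    "bngain": np.abs(dbngain - numerical_grad(bngain)).max(),
    "bnraw": np.abs(dbnraw - numerical_grad(bnraw)).max(),
    "bnbias": np.abs(dbnbias - numerical_grad(bnbias)).max(),
}
print(errors)
```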
59:40
So now we get to the batchnorm layer. We see how here bngain and bnbias are the parameters, so the backpropagation ends there, but bnraw now is the output of the standardization. So here what I'm doing, of course, is breaking up the batchnorm into manageable pieces so we can backpropagate through each line individually. But basically, here is what's happening. bnmeani is the sum, so this mu here is bnmeani (I apologize for the variable naming). bndiff is x minus mu. bndiff2 is x minus mu squared, here inside the variance. bnvar is the variance, so sigma squared; this is bnvar, and it's basically the sum of squares: this is the x minus mu squared, and then the sum.
1:00:30
Now, you'll notice one departure here. Here it is normalized as 1 over m, which is the number of examples; here I'm normalizing as 1 over n minus 1 instead of n. This is deliberate, and I'll come back to that in a bit when we are at this line. It is something called Bessel's correction, but this is how I want it in our case.
1:00:53
bnvar_inv then becomes basically bnvar plus epsilon (epsilon is 1e-5), and then one over the square root. That is the same as raising to the power of negative 0.5, right, because 0.5 is the square root and then the negative makes it one over the square root. So bnvar_inv is one over this denominator here. And then we can see that bnraw, which is the x hat here, is equal to bndiff, the numerator, multiplied by bnvar_inv. And this line here that creates hpreact was the last piece; we've already backpropagated through it.
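The forward decomposition described above can be sketched line by line in NumPy. The input name hprebn, the random data, and the unit gain and zero bias are assumptions for illustration. The checks confirm two things. First, the 1/(n-1) normalization matches NumPy's variance with ddof=1, which is Bessel's correction. Second, raising to the power -0.5 is the same as taking one over the square root.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 32
hprebn = rng.standard_normal((n, 64)) * 3.0 + 1.0  # pre-batchnorm activations (assumed name)
bngain = np.ones((1, 64))
bnbias = np.zeros((1, 64))

bnmeani = 1 / n * hprebn.sum(0, keepdims=True)       # mu, shape (1, 64)
bndiff = hprebn - bnmeani                            # x - mu
bndiff2 = bndiff**2                                  # (x - mu)^2
bnvar = 1 / (n - 1) * bndiff2.sum(0, keepdims=True)  # Bessel's correction: divide by n - 1
bnvar_inv = (bnvar + 1e-5) ** -0.5                   # 1 / sqrt(var + eps)
bnraw = bndiff * bnvar_inv                           # x hat
hpreact = bngain * bnraw + bnbias                    # scale and shift

# Cross-checks against NumPy built-ins.
print(np.allclose(bnvar, hprebn.var(axis=0, ddof=1, keepdims=True)))
print(np.allclose(bnvar_inv, 1 / np.sqrt(bnvar + 1e-5)))
print(np.allclose(bnraw.mean(0), 0))
```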
1:01:34
So now what we want to do is this: we are here, and we have bnraw, and we have to first backpropagate into bndiff and bnvar_inv. So now we're here, and we have dbnraw, and we need to backpropagate through this line. Now, I've written out the shapes here, and indeed bnvar_inv is of shape 1 by 64, so there is broadcasting happening here that we have to be careful with. But it is just a simple element-wise multiplication; by now we should be pretty comfortable with that. To get dbndiff, we know that this is just bnvar_inv multiplied with dbnraw. And conversely, to get dbnvar_inv, we need to take bndiff and multiply that by dbnraw.
1:02:22
so this is the candidate but of course
1:02:24
so this is the candidate but of course
1:02:24
so this is the candidate but of course we need to make sure that broadcasting
1:02:26
we need to make sure that broadcasting
1:02:26
we need to make sure that broadcasting is obeyed so in particular B and VAR M
1:02:29
is obeyed so in particular B and VAR M
1:02:29
is obeyed so in particular B and VAR M multiplying with DB and raw
1:02:31
multiplying with DB and raw
1:02:31
multiplying with DB and raw will be okay and give us 32 by 64 as we
1:02:35
will be okay and give us 32 by 64 as we
1:02:35
will be okay and give us 32 by 64 as we expect
1:02:36
expect
1:02:36
expect but dbm VAR inv would be taking a 32 by
1:02:40
but dbm VAR inv would be taking a 32 by
1:02:40
but dbm VAR inv would be taking a 32 by 64.
1:02:42
64.
1:02:42
64. multiplying it by 32 by 64. so this is a
1:02:45
multiplying it by 32 by 64. so this is a
1:02:45
multiplying it by 32 by 64. so this is a 32 by 64. but of course DB this uh B and
1:02:49
32 by 64. but of course DB this uh B and
1:02:49
32 by 64. but of course DB this uh B and VAR in is only 1 by 64. so the second
1:02:52
VAR in is only 1 by 64. so the second
1:02:52
VAR in is only 1 by 64. so the second line here needs a sum across the
1:02:55
line here needs a sum across the
1:02:55
line here needs a sum across the examples and because there's this
1:02:57
examples and because there's this
1:02:57
examples and because there's this Dimension here we need to make sure that
1:02:59
Dimension here we need to make sure that
1:03:00
Dimension here we need to make sure that keep them is true
1:03:02
keep them is true
1:03:02
keep them is true so this is the candidate
1:03:04
so this is the candidate
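The two candidate formulas can be checked in isolation. The following is a minimal sketch, not the lecture's notebook code: it assumes random float64 stand-in tensors with the shapes mentioned (bndiff is 32 by 64, bnvar_inv is 1 by 64) and a random upstream gradient dbnraw, and compares the hand-derived gradients against PyTorch autograd.

```python
import torch

torch.manual_seed(0)
n = 32                 # batch size, matching the 32 by 64 shapes above
dt = torch.float64     # double precision keeps the autograd comparison tight

# Random stand-ins for the real activations (assumption, for illustration only)
bndiff = torch.randn(n, 64, dtype=dt, requires_grad=True)
bnvar_inv = (torch.rand(1, 64, dtype=dt) + 0.5).requires_grad_()

# Forward: the (1, 64) row vector is broadcast down the 32 rows
bnraw = bndiff * bnvar_inv

# Stand-in upstream gradient arriving from the later layers
dbnraw = torch.randn(n, 64, dtype=dt)
bnraw.backward(dbnraw)

# Manual backward through bnraw = bndiff * bnvar_inv
dbndiff = bnvar_inv * dbnraw                          # (1,64) * (32,64) -> (32,64)
dbnvar_inv = (bndiff * dbnraw).sum(0, keepdim=True)   # sum over the batch -> (1,64)

print(dbndiff.shape, dbnvar_inv.shape)
print(torch.allclose(dbndiff, bndiff.grad), torch.allclose(dbnvar_inv, bnvar_inv.grad))
```

Without keepdim=True the second line would produce shape (64,) instead of (1, 64), which is why the shape is printed alongside the comparison.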
1:03:07
Let's erase this, swing down here, and implement it, and then let's uncomment the checks for dbnvar_inv and dbndiff.
1:03:18
Now, we'll actually notice that dbndiff, by the way, is going to be incorrect. When I run this, bnvar_inv is correct and bndiff is not, and this is actually expected, because we're not done with bndiff. In particular, when we look up here, we see that bnraw is a function of bndiff directly, but bnvar_inv is a function of bnvar, which is a function of bndiff2, which is a function of bndiff. So bndiff (sorry, these variable names are crazy) branches out into two branches, and we've only done one of them. We have to continue our backpropagation, eventually come back to bndiff, and then we'll be able to do a += and get the actual correct gradient.
1:04:05
For now, it is good to verify that cmp also works: it doesn't just lie to us and tell us that everything is always correct. It can in fact detect when your gradient is not correct, so that's good to see as well.
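A gradient check only helps if it can fail. The sketch below uses a deliberately tiny stand-in graph: a hypothetical x playing the role of bndiff, used once directly and once through its square (loosely mirroring the bndiff2 branch), plus a hypothetical cmp helper written in the spirit of the lecture's, not copied from it. The gradient from one branch alone is flagged as wrong; accumulating the second branch with += fixes it.

```python
import torch

def cmp(name, manual, t):
    # Hypothetical helper: compare a hand-derived gradient with the one
    # autograd stored in t.grad, and report the largest deviation
    ok = torch.allclose(manual, t.grad)
    maxdiff = (manual - t.grad).abs().max().item()
    print(f"{name}: approximate match: {ok}, maxdiff: {maxdiff:.3e}")
    return ok

torch.manual_seed(0)
dt = torch.float64
x = torch.randn(4, 3, dtype=dt, requires_grad=True)   # stand-in for bndiff
w = torch.randn(1, 3, dtype=dt)                       # stand-in for bnvar_inv (held constant here)

# x is used twice: directly, and through its square
y = x * w + x**2
dy = torch.randn(4, 3, dtype=dt)   # stand-in upstream gradient
y.backward(dy)

dx = w * dy                       # first branch only
partial_ok = cmp('x', dx, x)      # reports a mismatch
dx += 2 * x.detach() * dy         # accumulate the second branch with +=
full_ok = cmp('x', dx, x)         # now matches
```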
1:04:15
Okay, so now we have the derivative here, and we're trying to backpropagate through this line. Because we're raising to a power of -0.5, I brought up the power rule, and we see that dbnvar will be: bring down the exponent, so -0.5 times x, which is this expression, now raised to the power of -0.5 - 1, which is -1.5.
1:04:42
We would also have to apply a small chain rule here in our head, because we need to further take the derivative of the expression inside the bracket (bnvar plus the small epsilon) with respect to bnvar. But because this is an element-wise operation and everything is fairly simple, that's just 1, so there's nothing to do there. So this is the local derivative, and then, times the global derivative to complete the chain rule, this is just times dbnvar_inv. So this is our candidate. Let me bring this down and uncomment the check, and we see that we have the correct result.
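The power-rule step can be verified the same way. A minimal sketch, assuming a random positive 1 by 64 stand-in for bnvar, a random upstream gradient dbnvar_inv, and an epsilon of 1e-5 inside the bracket (an assumed value; use whatever constant the forward pass actually adds):

```python
import torch

torch.manual_seed(0)
dt = torch.float64
eps = 1e-5   # assumed epsilon; must match the forward pass

bnvar = (torch.rand(1, 64, dtype=dt) + 0.1).requires_grad_()   # stand-in positive variances
bnvar_inv = (bnvar + eps)**-0.5

dbnvar_inv = torch.randn(1, 64, dtype=dt)   # stand-in upstream gradient
bnvar_inv.backward(dbnvar_inv)

# Power rule: d/du u**-0.5 = -0.5 * u**-1.5, and d(bnvar + eps)/d(bnvar) = 1
dbnvar = (-0.5 * (bnvar.detach() + eps)**-1.5) * dbnvar_inv
print(torch.allclose(dbnvar, bnvar.grad))   # True
```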
1:05:19
Now, before we backpropagate through the next line, I want to briefly talk about the note here where I'm using Bessel's correction, dividing by n - 1 instead of dividing by n when I normalize the sum of squares. You'll notice that this is a departure from the paper, which uses 1/n instead of 1/(n - 1); their m is our n.
1:05:40
It turns out that there are two ways of estimating the variance of an array: one is the biased estimate, which is 1/n, and the other is the unbiased estimate, which is 1/(n - 1).
1:05:54
Now, confusingly, in the paper this is not very clearly described, and it's a detail that I think kind of matters. They are using the biased version at training time, but later, when they talk about inference, they mention that they use the unbiased estimate, the n - 1 version, basically for inference and to calibrate the running mean and the running variance. So they actually introduce a train-test mismatch: in training they use the biased version, and at test time they use the unbiased version. I find this extremely confusing.
1:06:30
You can read more about Bessel's correction and why dividing by n - 1 gives you a better estimate of the variance when you only have a small number of samples from the population. That is indeed the case for us, because we are dealing with mini-batches, and these mini-batches are a small sample of a larger population, which is the entire training set. It just turns out that if you estimate the variance using 1/n, you underestimate it on average; it is a biased estimator, and it is advised that you use the unbiased version and divide by n - 1. You can go through this article here that I liked, which describes the full reasoning, and I'll link it in the video description.
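To see the underestimation concretely, here is a small simulation sketch (the trial count and seed are arbitrary choices): it draws many mini-batches of 32 values from a standard normal distribution, whose true variance is 1, and averages both estimators. In expectation the 1/n estimate equals (n - 1)/n times the true variance, while the 1/(n - 1) estimate equals it exactly.

```python
import torch

torch.manual_seed(0)
n = 32            # mini-batch size, as in the lecture
trials = 10_000   # number of simulated mini-batches (arbitrary)

# Each row is one mini-batch drawn from a population with true variance 1.0
x = torch.randn(trials, n, dtype=torch.float64)
sq = (x - x.mean(dim=1, keepdim=True))**2

biased = sq.sum(dim=1) / n           # 1/n estimator
unbiased = sq.sum(dim=1) / (n - 1)   # 1/(n-1), Bessel's correction

# In expectation: biased -> (n-1)/n = 0.96875, unbiased -> 1.0
print(f"mean biased estimate:   {biased.mean().item():.4f}")
print(f"mean unbiased estimate: {unbiased.mean().item():.4f}")
```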
1:07:13
Now, when you calculate torch.var, you'll notice that it takes an unbiased flag for whether you want to divide by n or n - 1. Confusingly, the docs do not mention what the default is for unbiased, but I believe it is True by default. I'm not sure why the docs here don't state that.
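Rather than relying on the documentation, the default can be checked empirically. A minimal sketch on random data; it only assumes that calling Tensor.var without specifying the correction uses the library default, whatever the installed version calls that argument (as far as I know, recent releases name it correction and older ones unbiased):

```python
import torch

torch.manual_seed(0)
x = torch.randn(32, 64, dtype=torch.float64)
n = x.shape[0]
sq = (x - x.mean(0, keepdim=True))**2

default_var = x.var(0, keepdim=True)            # no correction argument given
unbiased = sq.sum(0, keepdim=True) / (n - 1)    # divide by n-1
biased = sq.sum(0, keepdim=True) / n            # divide by n

print(torch.allclose(default_var, unbiased))    # does the default divide by n-1?
print(torch.allclose(default_var, biased))      # or by n?
```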
1:07:33
Now, in BatchNorm1d the documentation is again kind of wrong and confusing. It says that the standard deviation is calculated via the biased estimator, but this is actually not exactly right, and people have pointed that out in a number of issues since then, because the rabbit hole is deeper: they follow the paper exactly and use the biased version for training, but when they estimate the running standard deviation, they use the unbiased version. So again, there's the train-test mismatch.
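The mismatch can be observed directly. A minimal sketch, assuming torch.nn.BatchNorm1d with its default eps and its affine parameters at initialization (weight 1, bias 0), and momentum=1.0 so that the running variance is replaced by the estimate from this single batch:

```python
import torch

torch.manual_seed(0)
x = torch.randn(32, 64)
n = x.shape[0]

# momentum=1.0: running_var becomes exactly the variance estimate from this one batch
bn = torch.nn.BatchNorm1d(64, momentum=1.0)
bn.train()
y = bn(x)

centered = x - x.mean(0)
biased_var = (centered ** 2).sum(0) / n          # 1/n
unbiased_var = (centered ** 2).sum(0) / (n - 1)  # 1/(n-1)

# Training-time output is normalized with the biased estimate
# (weight is 1 and bias is 0 at initialization, so y is just the normalized input)
y_manual = centered / torch.sqrt(biased_var + bn.eps)
print(torch.allclose(y, y_manual, atol=1e-5))

# The running variance, used later in eval mode, stores the unbiased estimate
print(torch.allclose(bn.running_var, unbiased_var, atol=1e-5))
```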
1:08:02
So, long story short, I'm not a fan of train-test discrepancies. I basically consider the fact that we use the biased version at training time and the unbiased version at test time to be a bug, and I don't think there's a good reason for it; the paper doesn't really go into the reasoning behind it. That's why I prefer to use Bessel's correction in my own work.
1:08:29
Unfortunately, BatchNorm does not take a keyword argument that lets you choose whether to use the unbiased or the biased version in both train and test, and so, in my view, anyone using batch normalization basically has a bit of a bug in their code. This turns out to be much less of a problem if your mini-batch sizes are a bit larger, but I still find it kind of unpalatable. Maybe someone can explain why this is okay, but for now I prefer to use the unbiased version consistently, both during training and at test time, and that's why I'm using 1/(n - 1) here.
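For contrast, a consistent variant is easy to sketch. The two helpers below are hypothetical (not a PyTorch API): they use the 1/(n - 1) estimate both to normalize during training and to update the running statistics used at evaluation time, so on the same batch the two modes agree.

```python
import torch

def batchnorm_train_step(x, running_mean, running_var, momentum=0.1, eps=1e-5):
    # Hypothetical helper: normalizes a (batch, features) tensor with the
    # Bessel-corrected variance and updates the running stats in place
    n = x.shape[0]
    mean = x.mean(0, keepdim=True)
    var = ((x - mean) ** 2).sum(0, keepdim=True) / (n - 1)   # 1/(n-1) during training...
    with torch.no_grad():
        running_mean.mul_(1 - momentum).add_(momentum * mean)
        running_var.mul_(1 - momentum).add_(momentum * var)  # ...and the same estimate is kept for eval
    return (x - mean) / torch.sqrt(var + eps)

def batchnorm_eval(x, running_mean, running_var, eps=1e-5):
    # Hypothetical helper: normalizes with the stored running statistics
    return (x - running_mean) / torch.sqrt(running_var + eps)

torch.manual_seed(0)
x = torch.randn(32, 64)
rm, rv = torch.zeros(1, 64), torch.ones(1, 64)

# momentum=1.0 so the running stats come entirely from this one batch
y_train = batchnorm_train_step(x, rm, rv, momentum=1.0)
y_eval = batchnorm_eval(x, rm, rv)
print(torch.allclose(y_train, y_eval))   # True: no train-test mismatch on the same batch
```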
1:09:03
Okay, so let's now actually backpropagate through this line. The first thing I always like to do is scrutinize the shapes. Looking at the shapes of what's involved here, I see that bnvar has shape 1 by 64, so it's a row vector, and bndiff2.shape is 32 by 64.
1:09:25
So clearly we're doing a sum over the zeroth axis to squash the first dimension of the shapes here. That right away hints to me that there will be some kind of replication or broadcasting in the backward pass. Maybe you're noticing the pattern here: any time you have a sum in the forward pass, that turns into a replication or broadcasting in the backward pass along the same dimension. Conversely, when we have a replication or broadcasting in the forward pass, that indicates a variable reuse, and so in the backward pass that turns into a sum over the exact same dimension. So hopefully you're noticing that duality: the two are kind of the opposite of each other in the forward and backward pass.
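Both directions of this duality can be confirmed with autograd. A minimal sketch with random float64 stand-in tensors of the shapes discussed (32 by 64 and 1 by 64):

```python
import torch

torch.manual_seed(0)
dt = torch.float64

# 1) Sum in the forward pass -> broadcast (replication) in the backward pass
a = torch.randn(32, 64, dtype=dt, requires_grad=True)
s = a.sum(0, keepdim=True)                 # (32, 64) -> (1, 64)
ds = torch.randn(1, 64, dtype=dt)          # stand-in upstream gradient
s.backward(ds)
print(torch.allclose(a.grad, ds.expand(32, 64)))   # every row receives a copy of ds

# 2) Broadcast in the forward pass -> sum in the backward pass
r = torch.randn(1, 64, dtype=dt, requires_grad=True)
out = r * torch.ones(32, 64, dtype=dt)     # r is reused across all 32 rows
dout = torch.randn(32, 64, dtype=dt)
out.backward(dout)
print(torch.allclose(r.grad, dout.sum(0, keepdim=True)))   # the reuses are summed back up
```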
1:10:11
Now, once we understand the shapes, the next thing I always like to do is look at a toy example in my head to roughly understand how the variable dependencies go in the mathematical formula.
1:10:24
Here we have a two-dimensional array, bndiff2, which we are scaling by a constant and then summing vertically over the columns. So if we have a 2 by 2 matrix a, and we sum over the columns and scale, we get a row vector b1, b2: b1 is the scaled sum of the first column of a, and b2 is the scaled sum of the second column.
1:10:52
So, looking at this, what we want to do now is take the derivatives on b1 and b2 and backpropagate them into a. Just differentiating in your head, it's clear that the local derivative here is 1/(n - 1) times 1 for each one of these a's. Basically, the derivative of b1 has to flow into every element of the first column of a (and that of b2 into the second column), scaled by 1/(n - 1), and that's roughly what's happening here.
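The toy example, written out as a sketch: a is the 2 by 2 stand-in for bndiff2, b is the row vector standing in for bnvar, L is the loss, and n is the number of rows, with the 1/(n - 1) scale kept symbolic.

```latex
\begin{aligned}
b_1 &= \tfrac{1}{n-1}\,(a_{11} + a_{21}), \qquad b_2 = \tfrac{1}{n-1}\,(a_{12} + a_{22}) \\
\frac{\partial b_j}{\partial a_{ij}} &= \tfrac{1}{n-1}, \qquad \frac{\partial b_j}{\partial a_{ik}} = 0 \quad (k \neq j) \\
\frac{\partial L}{\partial a_{ij}} &= \tfrac{1}{n-1}\,\frac{\partial L}{\partial b_j}
\end{aligned}
```

The last line is the replication rule in miniature: each column of a receives a copy of its column's upstream gradient, scaled by 1/(n - 1).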
1:11:18
and that's roughly What's Happening Here so intuitively the derivative flow tells
1:11:21
so intuitively the derivative flow tells
1:11:21
so intuitively the derivative flow tells us that DB and diff2
1:11:24
us that DB and diff2
1:11:24
will be the local derivative of this operation. There are many ways to do this, by the way, but I like to do something like this: torch.ones_like of bndiff2. So I'll create a large two-dimensional array of ones, and then I will scale it by 1.0 divided by n minus 1. So this is an array of 1/(n-1), and that's sort of like the local derivative. And now, for the chain rule, I will simply multiply it by dbnvar.
1:11:59
And notice here what's going to happen: this is 32 by 64, and this is just 1 by 64. So I'm letting the broadcasting do the replication, because internally in PyTorch, dbnvar, which is a 1 by 64 row vector, will in this multiplication get copied vertically until the two are of the same shape, and then there will be an elementwise multiply. So the broadcasting is basically doing the replication, and I will end up with the derivatives of dbndiff2 here.
1:12:32
So this is the candidate solution. Let's bring it down here, uncomment the line where we check it, and hope for the best. And indeed, we see that this is the correct formula.
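As a minimal sketch of this step, assuming the forward line bnvar = 1/(n-1) * bndiff2.sum(0, keepdim=True) and random stand-in tensors (the seed, the values and the injected upstream gradient dbnvar are all hypothetical), the manual gradient can be checked against PyTorch autograd. Each element of bndiff2 enters exactly one column sum with weight 1/(n-1), which is why the local derivative is a constant array of 1/(n-1).

```python
import torch

# Sketch with hypothetical random data; n = 32 rows and 64 features as in the lecture.
torch.manual_seed(0)
n = 32
bndiff2 = torch.rand(n, 64, requires_grad=True)  # stand-in for bndiff**2

# Forward pass: Bessel-corrected variance, summed down the batch dimension.
bnvar = 1.0 / (n - 1) * bndiff2.sum(0, keepdim=True)  # shape (1, 64)

# Inject a hypothetical upstream gradient dbnvar via a weighted sum, then run autograd.
dbnvar = torch.randn(1, 64)
(bnvar * dbnvar).sum().backward()

# Manual backward: a constant local derivative of 1/(n-1) for every element,
# times dbnvar. Broadcasting copies the (1, 64) row down all 32 rows.
dbndiff2 = (1.0 / (n - 1)) * torch.ones_like(bndiff2) * dbnvar

print(dbndiff2.shape, torch.allclose(dbndiff2, bndiff2.grad))  # torch.Size([32, 64]) True
```

The multiply by torch.ones_like is kept only to mirror the "local derivative times upstream gradient" pattern; broadcasting alone would give the same result.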
1:12:43
Next up, let's differentiate here into bndiff. So here we have that bndiff is elementwise squared to create bndiff2. This is a relatively simple derivative because it's a simple elementwise operation, so it's kind of like the scalar case. We have that dbndiff should be: if this is x squared, then the derivative of this is 2x, right? So it's simply 2 times bndiff; that's the local derivative, and then times the chain rule. The shapes of these are the same, so times this. So that's the backward pass for this variable. Let me bring that down here.
1:13:22
And now we have to be careful, because we already calculated dbndiff, right? This is just the end of the other branch coming back to bndiff, because bndiff was already backpropagated way over here from bnraw. So we have now completed the second branch, and that's why I have to do plus-equals. If you recall, we had an incorrect derivative for bndiff before, and I'm hoping that once we append this last missing piece, we have exact correctness. So let's run, and bndiff now actually shows the exact correct derivative. So that's comforting.
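The two-branch accumulation can be sketched in isolation. This sketch assumes the other branch is bnraw = bndiff * bnvar_inv, treats bnvar_inv as a constant, and injects independent, hypothetical random upstream gradients for both outputs, so it only checks the local rules (a copy through the multiply, 2x through the square) and the need for +=.

```python
import torch

torch.manual_seed(0)
n = 32
bndiff = torch.randn(n, 64, requires_grad=True)
bnvar_inv = torch.rand(1, 64) + 0.5  # hypothetical, held constant in this sketch

# Forward pass: bndiff is used by two branches.
bndiff2 = bndiff ** 2        # squared branch (feeds the variance)
bnraw = bndiff * bnvar_inv   # normalization branch

# Hypothetical upstream gradients for both outputs.
dbndiff2 = torch.randn(n, 64)
dbnraw = torch.randn(n, 64)
((bndiff2 * dbndiff2).sum() + (bnraw * dbnraw).sum()).backward()

with torch.no_grad():
    # First branch: gradient arriving from bnraw.
    dbndiff = bnvar_inv * dbnraw
    # Second branch: d(x**2)/dx = 2x; shapes match, so a plain elementwise multiply.
    # += because both branches deposit gradient into the same variable.
    dbndiff += (2 * bndiff) * dbndiff2

print(torch.allclose(dbndiff, bndiff.grad, atol=1e-6))
```

Replacing the += with a plain assignment drops the bnraw contribution and the comparison fails, which matches the incorrect bndiff derivative mentioned before the second branch was added.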
1:13:59
Okay, so let's now backpropagate through this line here. The first thing we do, of course, is check the shapes, and I wrote them out here. Basically, the shape of this is 32 by 64, hprebn is the same shape, but bnmeani is a row vector, 1 by 64. So this minus here will actually do broadcasting, and we have to be careful with that. As a hint to us, again because of the duality, a broadcasting in the forward pass means a variable reuse, and therefore there will be a sum in the backward pass.
1:14:31
So let's write out the backward pass here now. To backpropagate into hprebn: because these are the same shape, the local derivative for each one of the elements here is just one for the corresponding element in here. So basically, the gradient simply copies; it's just a variable assignment, it's equality. So I'm just going to clone this tensor, just for safety, to create an exact copy of dbndiff.
1:15:01
And then here, to backpropagate into this one, dbnmeani will basically be... what is the local derivative? Well, it's negative torch.ones_like of the shape of bndiff, right, and then times the derivative here, dbndiff. And this here is the backpropagation for the replicated bnmeani, so I still have to backpropagate through the replication in the broadcasting, and I do that by doing a sum. So I'm going to take this whole thing and do a sum over the zeroth dimension, which was the replication.
1:15:53
If you scrutinize this, by the way, you'll notice that this is the same shape as that, and so what I'm doing here doesn't actually make that much sense, because it's just an array of ones multiplying dbndiff. So in fact I can just do this, (-dbndiff).sum(0), and that is equivalent.
1:16:15
So this is the candidate backward pass. Let me copy it here, uncomment this one and this one, enter... and it's wrong. Damn. Actually, sorry, this is supposed to be wrong, and it's supposed to be wrong because we are backpropagating from bndiff into hprebn, but we're not done: bnmeani depends on hprebn, and there will be a second portion of that derivative coming from this second branch. So we're not done yet, and we expect it to be incorrect. So there you go.
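The broadcast-then-sum duality for the subtraction can be checked on its own. In this sketch bnmeani is deliberately made an independent leaf tensor (an assumption for illustration only), so the first branch alone matches autograd; in the actual network bnmeani is computed from hprebn, which is why the check described above still fails until the second branch is added. The keepdim=True is a choice made here so dbnmeani keeps the 1 by 64 shape of bnmeani; the random values and upstream gradient are hypothetical.

```python
import torch

torch.manual_seed(0)
n = 32
hprebn = torch.randn(n, 64, requires_grad=True)
# Illustration only: bnmeani as an independent leaf, not computed from hprebn.
bnmeani = torch.randn(1, 64, requires_grad=True)

# Forward pass: the (1, 64) row is broadcast, i.e. reused, across all 32 rows.
bndiff = hprebn - bnmeani

dbndiff = torch.randn(n, 64)  # hypothetical upstream gradient
(bndiff * dbndiff).sum().backward()

with torch.no_grad():
    # Same shapes, local derivative 1: the gradient is simply copied.
    dhprebn = dbndiff.clone()
    # Local derivative -1, then a sum over dim 0 undoes the broadcast.
    # (-torch.ones_like(bndiff) * dbndiff).sum(0, keepdim=True) is the spelled-out form.
    dbnmeani = (-dbndiff).sum(0, keepdim=True)  # stays (1, 64), like bnmeani

print(torch.allclose(dhprebn, hprebn.grad), torch.allclose(dbnmeani, bnmeani.grad, atol=1e-6))
```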
1:16:53
So let's now backpropagate from bnmeani into hprebn. Here again we have to be careful, because there's a sum along the zeroth dimension, so this will turn into broadcasting in the backward pass. I'm going to go a little bit faster on this line because it is very similar to the line that we had before, and to multiple lines in the past, in fact.
1:17:18
So for dhprebn, the gradient will be scaled by 1 over n: this gradient here on dbnmeani is going to be scaled by 1 over n, and then it's going to flow down each column and deposit itself into all the rows of dhprebn. So what we want is this thing scaled by 1 over n. I'll put the constant up front here, so scale down the gradient, and now we need to replicate it across all the rows here. I like to do that with torch.ones_like of basically hprebn, and I will let the broadcasting do the work of replication, like that. So this is dhprebn, and hopefully we can plus-equals that. This here is broadcasting, and then this is the scaling, so this should be correct. Okay.
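Putting both branches together, a sketch of the full backward pass through the batch mean and the centering might look like this, assuming the forward lines bnmeani = 1/n * hprebn.sum(0, keepdim=True) and bndiff = hprebn - bnmeani, with hypothetical random values and a hypothetical upstream gradient dbndiff.

```python
import torch

torch.manual_seed(0)
n = 32
hprebn = torch.randn(n, 64, requires_grad=True)

# Forward pass: batch mean, then centered values.
bnmeani = 1.0 / n * hprebn.sum(0, keepdim=True)  # (1, 64)
bndiff = hprebn - bnmeani                         # (32, 64)

dbndiff = torch.randn(n, 64)  # hypothetical upstream gradient
(bndiff * dbndiff).sum().backward()

with torch.no_grad():
    # Branch 1: direct path hprebn -> bndiff (copy), and the broadcast row (sum).
    dhprebn = dbndiff.clone()
    dbnmeani = (-dbndiff).sum(0, keepdim=True)
    # Branch 2: hprebn -> bnmeani. The forward sum over dim 0 becomes a broadcast:
    # scale by 1/n and replicate the (1, 64) row down all n rows, accumulating with +=.
    dhprebn += 1.0 / n * (torch.ones_like(hprebn) * dbnmeani)

print(torch.allclose(dhprebn, hprebn.grad, atol=1e-6))
```

Without the second += line the comparison fails, which is the "supposed to be wrong" result from the first branch alone.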
1:18:37
So that completes the backpropagation of the batchnorm layer, and we are now here. Let's backpropagate through the linear layer 1 here. Now, because everything is getting a little vertically crazy, I copy-pasted the line here, and let's just backpropagate through this one line.
1:18:52
So first, of course, we inspect the shapes, and we see that this is 32 by 64, embcat is 32 by 30, W1 is 30 by 64, and b1 is just 64. As I mentioned, backpropagating through linear layers is fairly easy just by matching the shapes, so let's do that. We have that dembcat should be some matrix multiplication of dhprebn with W1, with one transpose thrown in there. So to make dembcat be 32 by 30, I need to take dhprebn, which is 32 by 64, and multiply it by W1 transpose. To get dW1, I need to end up with 30 by 64, so I need to take embcat transpose and multiply that by dhprebn.
1:19:58
And finally, to get db1: this is an addition, and we saw that basically I need to just sum the elements in dhprebn along some dimension. To make the dimensions work out, I need to sum along the zeroth axis here to eliminate this dimension, and we do not keep dims, because we want to just get a single one-dimensional vector of 64. So these are the claimed derivatives. Let me put that here, uncomment three lines, and cross our fingers. Everything is great.
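The shape-matching rules for the linear layer can be sketched and checked as follows, assuming the forward line hprebn = embcat @ W1 + b1 and the shapes quoted above; the random values and the upstream gradient dhprebn are hypothetical. A loose tolerance absorbs float32 rounding differences between the matrix multiplications.

```python
import torch

torch.manual_seed(0)
n = 32
embcat = torch.randn(n, 30, requires_grad=True)
W1 = torch.randn(30, 64, requires_grad=True)
b1 = torch.randn(64, requires_grad=True)

# Forward pass: (32, 30) @ (30, 64) + (64,) -> (32, 64); b1 is broadcast over rows.
hprebn = embcat @ W1 + b1

dhprebn = torch.randn(n, 64)  # hypothetical upstream gradient
(hprebn * dhprebn).sum().backward()

with torch.no_grad():
    dembcat = dhprebn @ W1.T  # (32, 64) @ (64, 30) -> (32, 30), the shape of embcat
    dW1 = embcat.T @ dhprebn  # (30, 32) @ (32, 64) -> (30, 64), the shape of W1
    db1 = dhprebn.sum(0)      # no keepdim -> (64,), the shape of b1

checks = [("embcat", dembcat, embcat.grad), ("W1", dW1, W1.grad), ("b1", db1, b1.grad)]
for name, mine, ref in checks:
    print(name, mine.shape == ref.shape, torch.allclose(mine, ref, atol=1e-4))
```

Each transpose is chosen only so the result has the shape of the tensor it is the gradient of; there is exactly one arrangement of dhprebn, W1 and embcat that makes the matrix shapes line up.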
1:20:36
everything is great okay so we now continue almost there we have the
1:20:37
continue almost there we have the
1:20:37
continue almost there we have the derivative of MCAT and we want to
1:20:39
derivative of MCAT and we want to
1:20:39
derivative of MCAT and we want to derivative we want to back propagate
1:20:41
derivative we want to back propagate
1:20:41
derivative we want to back propagate into m
1:20:43
into m
1:20:43
into m so I again copied this line over here
1:20:46
so I again copied this line over here
1:20:46
so I again copied this line over here so this is the forward pass and then
1:20:48
so this is the forward pass and then
1:20:48
so this is the forward pass and then this is the shapes so remember that the
1:20:51
this is the shapes so remember that the
1:20:51
this is the shapes so remember that the shape here was 32 by 30 and the original
1:20:53
shape here was 32 by 30 and the original
1:20:53
shape here was 32 by 30 and the original shape of M plus 32 by 3 by 10. so this
1:20:57
shape of M plus 32 by 3 by 10. so this
1:20:57
shape of M plus 32 by 3 by 10. so this layer in the forward pass as you recall
1:20:58
layer in the forward pass as you recall
1:20:58
layer in the forward pass as you recall did the concatenation of these three
1:21:01
did the concatenation of these three
1:21:01
did the concatenation of these three 10-dimensional character vectors
1:21:04
10-dimensional character vectors
1:21:04
10-dimensional character vectors and so now we just want to undo that
1:21:06
and so now we just want to undo that
1:21:06
and so now we just want to undo that so this is actually relatively
1:21:08
so this is actually relatively
1:21:08
so this is actually relatively straightforward operation because uh the
1:21:11
straightforward operation because uh the
1:21:11
straightforward operation because uh the backward pass of the what is the view
1:21:12
backward pass of the what is the view
1:21:12
backward pass of the what is the view view is just a representation of the
1:21:15
view is just a representation of the
1:21:15
view is just a representation of the array it's just a logical form of how
1:21:17
array it's just a logical form of how
1:21:17
array it's just a logical form of how you interpret the array so let's just
1:21:18
you interpret the array so let's just
1:21:18
you interpret the array so let's just reinterpret it to be what it was before
1:21:21
reinterpret it to be what it was before
1:21:21
reinterpret it to be what it was before so in other words the end is not uh 32
1:21:25
so in other words the end is not uh 32
1:21:25
so in other words the end is not uh 32 by 30. it is basically dmcat
1:21:29
by 30. it is basically dmcat
1:21:29
by 30. it is basically dmcat but if you view it as the original shape
1:21:34
but if you view it as the original shape
1:21:34
but if you view it as the original shape so just m dot shape
1:21:37
so just m dot shape
1:21:37
so just m dot shape uh you can you can pass in tuples into
1:21:39
uh you can you can pass in tuples into
1:21:39
uh you can you can pass in tuples into view
1:21:40
view
1:21:40
view and so this should just be okay
1:21:44
we just re-represent that view and then
1:21:47
we just re-represent that view and then
1:21:47
we just re-represent that view and then we uncomment this line here and
1:21:49
we uncomment this line here and
1:21:49
we uncomment this line here and hopefully
1:21:51
hopefully
1:21:51
hopefully yeah so the derivative of M is correct
1:21:55
yeah so the derivative of M is correct
1:21:55
yeah so the derivative of M is correct so in this case we just have to
1:21:56
so in this case we just have to
1:21:56
so in this case we just have to re-represent the shape of those
1:21:57
re-represent the shape of those
1:21:57
re-represent the shape of those derivatives into the original View
1:21:59
derivatives into the original View
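A small sketch of the view backward pass, assuming the forward line embcat = emb.view(emb.shape[0], -1) and random stand-in values (including a hypothetical upstream gradient dembcat). Because view only reinterprets the same numbers, the result should match autograd exactly, not just approximately, so torch.equal is used instead of torch.allclose.

```python
import torch

torch.manual_seed(0)
emb = torch.randn(32, 3, 10, requires_grad=True)  # 32 examples x 3 characters x 10-dim

# Forward pass: concatenate the three 10-dim vectors per example by reshaping.
embcat = emb.view(emb.shape[0], -1)  # (32, 30)

dembcat = torch.randn(32, 30)  # hypothetical upstream gradient
(embcat * dembcat).sum().backward()

# Backward pass: view only changes how the numbers are interpreted,
# so the gradient is viewed back into emb's original shape.
demb = dembcat.view(emb.shape)  # view accepts a torch.Size (a tuple subclass)

print(demb.shape, torch.equal(demb, emb.grad))  # torch.Size([32, 3, 10]) True
```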
1:21:59
derivatives into the original View so now we are at the final line and the
1:22:01
so now we are at the final line and the
1:22:01
so now we are at the final line and the only thing that's left to back propagate
1:22:02
only thing that's left to back propagate
1:22:02
only thing that's left to back propagate through is this indexing operation here
1:22:05
through is this indexing operation here
1:22:05
through is this indexing operation here MSC at xB so as I did before I copy
1:22:09
MSC at xB so as I did before I copy
1:22:09
MSC at xB so as I did before I copy pasted this line here and let's look at
1:22:11
pasted this line here and let's look at
1:22:11
pasted this line here and let's look at the shapes of everything that's involved
1:22:12
the shapes of everything that's involved
1:22:12
the shapes of everything that's involved and remind ourselves how this worked
1:22:15
and remind ourselves how this worked
1:22:15
and remind ourselves how this worked so m.shape was 32 by 3 by 10.
1:22:19
so m.shape was 32 by 3 by 10.
1:22:19
so m.shape was 32 by 3 by 10. it says 32 examples and then we have
1:22:22
it says 32 examples and then we have
1:22:22
it says 32 examples and then we have three characters each one of them has a
1:22:24
three characters each one of them has a
1:22:24
three characters each one of them has a 10 dimensional embedding
1:22:26
10 dimensional embedding
1:22:26
10 dimensional embedding and this was achieved by taking the
1:22:28
and this was achieved by taking the
1:22:28
and this was achieved by taking the lookup table C which have 27 possible
1:22:31
lookup table C which have 27 possible
1:22:31
lookup table C which have 27 possible characters
1:22:32
characters
1:22:32
characters each of them 10 dimensional and we
1:22:34
each of them 10 dimensional and we
1:22:34
each of them 10 dimensional and we looked up
1:22:35
looked up
1:22:35
looked up at the rows that were specified inside
1:22:38
at the rows that were specified inside
1:22:39
at the rows that were specified inside this tensor xB
1:22:41
this tensor xB
1:22:41
this tensor xB so XB is 32 by 3 and it's basically
1:22:43
so XB is 32 by 3 and it's basically
1:22:43
so XB is 32 by 3 and it's basically giving us for each example the Identity
1:22:45
giving us for each example the Identity
1:22:45
giving us for each example the Identity or the index of which character is part
1:22:49
or the index of which character is part
1:22:49
or the index of which character is part of that example
1:22:50
of that example
1:22:50
of that example and so here I'm showing the first five
1:22:52
and so here I'm showing the first five
1:22:52
and so here I'm showing the first five rows of three of this tensor xB
1:22:57
rows of three of this tensor xB
1:22:57
rows of three of this tensor xB and so we can see that for example here
1:22:58
and so we can see that for example here
1:22:58
and so we can see that for example here it was the first example in this batch
1:23:00
it was the first example in this batch
1:23:00
it was the first example in this batch is that the first character and the
1:23:02
is that the first character and the
1:23:02
is that the first character and the first character and the fourth character
1:23:04
first character and the fourth character
1:23:04
first character and the fourth character comes into the neural net
1:23:06
comes into the neural net
1:23:06
comes into the neural net and then we want to predict the next
1:23:08
and then we want to predict the next
1:23:08
and then we want to predict the next character in a sequence after the
1:23:10
character in a sequence after the
1:23:10
character in a sequence after the character is one one four
1:23:12
character is one one four
1:23:12
character is one one four so basically What's Happening Here is
1:23:14
so basically What's Happening Here is
1:23:14
so basically What's Happening Here is there are integers inside XB and each
1:23:17
there are integers inside XB and each
1:23:18
there are integers inside XB and each one of these integers is specifying
1:23:19
one of these integers is specifying
1:23:19
one of these integers is specifying which row of C we want to pluck out
1:23:22
which row of C we want to pluck out
1:23:22
1:23:25
Right, and then we arrange those rows that we've plucked out into a 32 by 3 by 10 tensor; we just package them into this tensor. And now what's happening is that we have demb: for every one of these plucked-out rows we now have their gradients, but they're arranged inside this 32 by 3 by 10 tensor. So all we have to do now is route this gradient backwards through this assignment. We need to find which row of C every one of these 10-dimensional embeddings came from, and then we need to deposit them into dC, so we just need to undo the indexing. And of course, if any of these rows of C was used multiple times, which is almost certainly the case (row 1, say, being used multiple times), then we have to remember that the gradients that arrive there have to add. So for each occurrence we have to have an addition.
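A minimal PyTorch sketch of why the gradients have to add: autograd's backward pass for plain indexing already accumulates the contributions when the same row of C is plucked out more than once. The tiny 3 by 2 table and the index list are made-up illustrative values, not the lecture's data.

```python
import torch

# Illustrative 3x2 table (the lecture's C is 27x10)
C = torch.randn(3, 2, requires_grad=True)
ix = torch.tensor([0, 0, 2])   # row 0 is plucked out twice, row 1 never
emb = C[ix]                    # forward: plain indexing, shape (3, 2)
emb.sum().backward()           # d(sum)/d(emb) is all ones

# Row 0 received two contributions of 1, row 1 none, row 2 one
print(C.grad)
```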
1:24:23
So let's now write this out. Unfortunately, I don't actually know of a much better way to do this in Python than a for loop, so maybe someone can come up with a vectorized, efficient operation, but for now let's just use for loops. Let me create torch.zeros_like(C) to initialize a 27 by 10 tensor of all zeros, and then, honestly, for k in range(Xb.shape[0]) (maybe someone has a better way to do this) and for j in range(Xb.shape[1]). This is going to iterate over all the elements of Xb, all these integers. Then let's get the index at this position: the index ix is basically Xb[k, j], so an example of that is 11 or 14 and so on.
1:25:19
Now, in the forward pass, we basically took the row of C at ix and deposited it into emb[k, j]. That's what happened; that's where they are packaged. So now we need to go backwards, and we just need to route demb at the position k, j. We now have these derivatives for each position, each one 10-dimensional, and we just need to go into the correct row of C. So dC at ix is this, but with plus equals: dC[ix] += demb[k, j]. It's += because there could be multiple occurrences; the same row could have been used many, many times, and so all of those derivatives will just go backwards through the indexing and they will add. So this is my candidate solution.
1:26:16
Let's copy it here, uncomment this, and cross our fingers. Hey! So that's it: we've backpropagated through this entire beast. So there we go, totally makes sense.
1:26:33
So now we come to exercise two. It basically turns out that in this first exercise we were doing way too much work; we were backpropagating way too much. It was all good practice, but it's not what you would do in practice. The reason is that, for example, here I separated out this loss calculation over multiple lines and broke it all up into its smallest atomic pieces, and we backpropagated through all of those individually. But it turns out that if you just look at the mathematical expression for the loss, you can actually do the differentiation with pen and paper: a lot of terms cancel and simplify, and the mathematical expression you end up with can be significantly shorter and easier to implement than backpropagating through all the little pieces of everything you've done.
1:27:16
So before, we had this complicated forward pass going from the logits to the loss, but in PyTorch everything can just be glued together into a single call, F.cross_entropy. You just pass in the logits and the labels, and you get the exact same loss, as I verify here: our previous loss and the fast loss, which computes the whole chunk of operations as a single mathematical expression, are the same. But it's much, much faster in the forward pass, and it's also much, much faster in the backward pass. The reason is that if you just look at the mathematical form of this and differentiate again, you will end up with a very small and short expression.
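As a rough check of the claim that the fused call gives the same loss, here is a sketch comparing a hand-written chain (subtract the row max, exponentiate, normalize, log, pluck out the label, average) with F.cross_entropy on random logits. The chain is written in the spirit of the lecture's broken-up version, not copied from it, and the random inputs are assumptions.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n = 32
logits = torch.randn(n, 27)
Yb = torch.randint(0, 27, (n,))

# Broken-up chain of small operations
counts = (logits - logits.max(1, keepdim=True).values).exp()  # max-subtraction for stability
probs = counts / counts.sum(1, keepdim=True)                   # softmax along each row
logprobs = probs.log()
loss_manual = -logprobs[range(n), Yb].mean()                   # pluck label, negate, average

# Single fused call
loss_fast = F.cross_entropy(logits, Yb)

print(loss_manual.item(), loss_fast.item())
```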
1:27:48
So that's what we want to do here: in a single operation, very quickly, we want to go directly to dlogits, and we need to implement dlogits as a function of logits and Yb. It will be significantly shorter than whatever we did here, where to get to dlogits we had to go all the way through here. All of this work can be skipped in favor of a much, much simpler mathematical expression that you can implement here. So you can give it a shot yourself: basically, look at what exactly the mathematical expression of the loss is and differentiate it with respect to the logits.
1:28:29
So let me show you a hint. You can of course try it fully yourself, but if not, I can give you some hint of how to get started mathematically.
1:28:36
So basically, what's happening here is we have logits, then there's a softmax that takes the logits and gives you probabilities, then we use the identity of the correct next character to pluck out a row of probabilities, take the negative log of it to get our negative log probability, and then we average up all the negative log probabilities to get our loss. So for a single individual example, what we have is that the loss is equal to the negative log probability, -log p_y, where p here is thought of as a vector of all the probabilities, taken at the y position, where y is the label. And p here is of course the softmax, so the i-th component of this probability vector is just the softmax function: exponentiating all the logits and normalizing so everything sums to 1, p_i = e^(l_i) / sum_j e^(l_j). Now if you write out p_y here, you can just write out the softmax, and then what we're interested in is the derivative of the loss with respect to the i-th logit. So basically it's d/dl_i of this expression here, -log( e^(l_y) / sum_j e^(l_j) ), where we have l indexed with the specific label y on top, a sum over j of e^(l_j) on the bottom, and the negative log of all that. So potentially give it a shot with pen and paper and see if you can actually derive the expression for dloss/dl_i, and then we're going to implement it here.
1:30:09
Okay, so I'm going to give away the result here. This is some of the math I did to derive the gradients analytically, and we see here that I'm just applying the rules of calculus from the first or second year of a bachelor's degree, if you took it, and that the expressions actually simplify quite a bit. You have to separate out the analysis into the case where the i-th index you're interested in inside the logits is equal to the label and the case where it's not; the expressions then simplify and cancel in slightly different ways, and what we end up with is something very, very simple. We either end up with p_i, where p is again this vector of probabilities after a softmax (when i is not the label), or p_i minus 1, where we simply subtract a one (when i is the label). In any case, we just need to calculate the softmax p, and then in the correct dimension we need to subtract one. That's the gradient, the form that it takes analytically.
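For reference, the single-example result can be written compactly. This route rewrites the negative log of the softmax as -l_y + log sum_j e^(l_j) before differentiating, so both cases (i = y and i not equal to y) come out of one indicator term rather than being split up front; the last line adds the 1/n from averaging over a batch of n examples, where each example's loss depends only on its own row of logits.

```latex
\begin{aligned}
p_i &= \frac{e^{l_i}}{\sum_j e^{l_j}} \\
\text{loss} &= -\log p_y = -l_y + \log \sum_j e^{l_j} \\
\frac{\partial\,\text{loss}}{\partial l_i}
  &= -\mathbb{1}[i = y] + \frac{e^{l_i}}{\sum_j e^{l_j}}
   = p_i - \mathbb{1}[i = y] \\
\text{loss}_{\text{batch}} &= \frac{1}{n}\sum_{b=1}^{n} \text{loss}_b
  \quad\Longrightarrow\quad
  \frac{\partial\,\text{loss}_{\text{batch}}}{\partial l_{b,i}}
  = \frac{1}{n}\bigl(p_{b,i} - \mathbb{1}[i = y_b]\bigr)
\end{aligned}
```

Here the indicator is 1 when i is the label and 0 otherwise, which gives p_i - 1 at the label and p_i everywhere else.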
1:31:03
So let's implement this. We have to keep in mind that this was only done for a single example, but here we are working with batches of examples, so we have to be careful of that. The loss for a batch is the average loss over all the examples; in other words, it's the loss for each individual example summed up and then divided by n, and we have to backpropagate through that as well and be careful with it. So dlogits is going to be F.softmax: PyTorch has a softmax function that you can call, and we want to apply the softmax on the logits in dimension 1, so basically we want to do the softmax along the rows of these logits. Then at the correct positions we need to subtract a 1: dlogits[range(n), Yb] -= 1, iterating over all the rows and indexing into the columns provided by the correct labels inside Yb. And then finally, the loss is the average loss, and in the average there's a 1/n over all the losses added up, so we need to also backpropagate through that division: the gradient has to be scaled down by n as well, because of the mean. But otherwise this should be the result.
1:32:24
So now if we verify this, we see that we don't get an exact match, but at the same time the maximum difference between dlogits from PyTorch and our dlogits here is on the order of 5e-9, so it's a tiny, tiny number. Because of floating-point wonkiness we don't get the exact bitwise result, but we basically get the correct answer, approximately.
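The batched formula can be sketched as follows, with random logits and labels assumed in place of the lecture's network outputs; PyTorch's own gradient through F.cross_entropy serves as the reference, mirroring the comparison described above.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n = 32
logits = torch.randn(n, 27, requires_grad=True)
Yb = torch.randint(0, 27, (n,))

loss = F.cross_entropy(logits, Yb)
loss.backward()                        # PyTorch's gradient, for comparison

with torch.no_grad():
    dlogits = F.softmax(logits, 1)     # softmax along each row (dim=1)
    dlogits[range(n), Yb] -= 1         # subtract 1 at the correct label in each row
    dlogits /= n                       # the loss is a mean over n examples

max_diff = (dlogits - logits.grad).abs().max().item()
print(max_diff)                        # tiny: floating-point noise only
```

The exact size of the difference depends on the random inputs, but it should be far below the magnitude of the gradients themselves (roughly 1/n or smaller).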
1:32:51
Now I'd like to pause here briefly before we move on to the next exercise, because I'd like us to get an intuitive sense of what dlogits is, because it has a beautiful and very simple explanation, honestly. So here I'm taking dlogits and visualizing it, and we can see that we have a batch of 32 examples of 27 characters. And what is dlogits intuitively? dlogits is the probabilities, the probs matrix from the forward pass, but these black squares here are the positions of the correct indices, where we subtracted a one. So what is this doing? These are the derivatives on the logits, so let's look at just the first row here. That's what I'm doing here: I'm calculating the probabilities of these logits and then taking just the first row, so this is the probability row, and then dlogits of the first row, multiplied by n just so that we don't have the scaling by n in here and everything is more interpretable. We see that it's exactly equal to the probability, of course, but the position of the correct index has a minus equals one, so minus one at that position.
1:33:57
position and so notice that
1:33:59
And so notice that if you take dlogits[0] and sum it, it actually sums to zero. So you should think of these gradients at each cell as a force: we are basically pulling down on the probabilities of the incorrect characters and pulling up on the probability at the correct index, and that's what's happening in each row. The amount of push and pull is exactly equalized because the sum is zero: the amount by which we pull down on the incorrect probabilities equals the amount by which we push up on the probability of the correct character.
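A minimal sketch of the row-sum property in plain Python (rather than the PyTorch used in the lecture): for a hypothetical toy batch of logits and target indices, it forms dlogits as (softmax - one-hot) / n, checks that every row sums to zero, and compares each entry against a central finite difference of the mean cross-entropy loss. The helper names are illustrative, not taken from the lecture's code.

```python
import math

def softmax(row):
    # subtract the row max before exponentiating, for numerical stability
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(logits, targets):
    # mean negative log-probability of the correct class over the batch
    n = len(logits)
    return -sum(math.log(softmax(row)[t]) for row, t in zip(logits, targets)) / n

def dlogits_analytic(logits, targets):
    # dL/dlogits = (softmax - one_hot) / n, one row per example
    n = len(logits)
    grads = []
    for row, t in zip(logits, targets):
        p = softmax(row)
        p[t] -= 1.0  # the "minus equals one" at the correct index
        grads.append([g / n for g in p])
    return grads

# hypothetical toy batch: 3 examples, 4 classes
logits = [[1.0, -0.5, 0.3, 2.0],
          [0.2, 0.1, -1.0, 0.5],
          [-2.0, 1.5, 0.0, 0.7]]
targets = [3, 0, 1]

dlogits = dlogits_analytic(logits, targets)
row_sums = [sum(row) for row in dlogits]  # each should be ~0

# central finite-difference check of every entry of dlogits
h = 1e-5
max_err = 0.0
for i in range(len(logits)):
    for j in range(len(logits[i])):
        plus = [r[:] for r in logits]
        minus = [r[:] for r in logits]
        plus[i][j] += h
        minus[i][j] -= h
        numeric = (cross_entropy(plus, targets) - cross_entropy(minus, targets)) / (2 * h)
        max_err = max(max_err, abs(numeric - dlogits[i][j]))

print(row_sums, max_err)
```

Multiplying a row of dlogits by n recovers the softmax probabilities with 1 subtracted at the target index, which is the interpretable form described above.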
1:34:43
So the repulsion and the attraction are balanced. Now think of the neural net as a massive pulley system or something like that. We're up here on top of the logits, pulling down the probabilities of the incorrect characters and pulling up the probability of the correct one, and because everything in this complicated pulley system is mathematically determined, think of this tension translating through the whole mechanism until eventually we get a tug on the weights and the biases. In each update we just tug in the direction that we like for each of these elements, and the parameters slowly give in to the tug. That's what training a neural net looks like on a high level.
1:35:24
And so I think the forces of push and pull in these gradients are actually very intuitive. We're pushing and pulling on the correct answer and the incorrect answers, and the amount of force that we're applying is proportional to the probabilities that came out in the forward pass. For example, if our probabilities came out exactly correct, with zero everywhere except for a one at the correct position, then dlogits would be a row of all zeros for that example and there would be no push and pull. So the amount by which your prediction is incorrect is exactly the amount of pull or push you get in that dimension.
1:36:04
So if you have, for example, a very confidently mispredicted element here, that element is going to be pulled down very heavily, the correct answer is going to be pulled up by the same amount, and the other characters are not going to be influenced much.
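To see the push and pull in numbers, this sketch compares the per-example gradient (softmax minus one-hot, without the 1/n factor) for two hypothetical logit rows: one that confidently predicts the correct class and one that confidently predicts a wrong class. Under gradient descent a positive entry means that logit gets pushed down and a negative entry means it gets pulled up.

```python
import math

def softmax(row):
    # subtract the row max before exponentiating, for numerical stability
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def row_grad(row, target):
    # gradient of -log softmax(row)[target] w.r.t. the logits (no 1/n scaling)
    g = softmax(row)
    g[target] -= 1.0
    return g

target = 0
confident_right = [10.0, 0.0, 0.0, 0.0]  # hypothetical: almost all mass on class 0
confident_wrong = [0.0, 0.0, 10.0, 0.0]  # hypothetical: almost all mass on class 2

g_right = row_grad(confident_right, target)  # close to all zeros: almost no tug
g_wrong = row_grad(confident_wrong, target)  # large +/- pair on classes 2 and 0

print(g_right)
print(g_wrong)
```

In the mispredicted row, the wrong class and the correct class receive gradients of nearly equal size and opposite sign, while the two uninvolved classes barely move.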
1:36:21
So the amount by which you mispredict is proportional to the strength of the pull, and that happens independently for every row of this tensor. It's very intuitive and easy to think through, and that's basically the magic of the cross-entropy loss and what it's doing dynamically in the backward pass of the neural net. So now we get to exercise number three, which is a very fun exercise, depending on your definition of fun.
1:36:45
We are going to do for batch normalization exactly what we did for the cross-entropy loss in exercise number two: we are going to consider it as a single glued-together mathematical expression and backpropagate through it in a very efficient manner, because we are going to derive a much simpler formula for the backward pass of batch normalization. And we're going to do that using pen and paper.
1:37:05
Previously we broke batch normalization up into all of its little intermediate pieces and atomic operations, and then we backpropagated through them one by one. Now we just have a single forward pass of batchnorm that's all glued together, and we see that we get the exact same result as before.
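A minimal sketch of the two forward passes for a single feature column, assuming a Bessel-corrected variance (divide by n - 1) and eps = 1e-5; the function names and input values are hypothetical. One version goes step by step through intermediate quantities, the other is a single glued expression, and both should agree up to floating-point rounding.

```python
import math

EPS = 1e-5  # assumed small constant added to the variance

def batchnorm_stepwise(x, gamma, beta):
    # one feature column across the batch, broken into atomic steps
    n = len(x)
    mean = sum(x) / n
    diff = [xi - mean for xi in x]
    diff2 = [d * d for d in diff]
    var = sum(diff2) / (n - 1)  # Bessel correction: divide by n - 1
    var_inv = (var + EPS) ** -0.5
    raw = [d * var_inv for d in diff]
    return [gamma * r + beta for r in raw]

def batchnorm_fused(x, gamma, beta):
    # the same computation written as one glued expression per element
    n = len(x)
    mean = sum(x) / n
    var = sum((xi - mean) ** 2 for xi in x) / (n - 1)
    return [gamma * (xi - mean) / math.sqrt(var + EPS) + beta for xi in x]

x = [0.5, -1.2, 3.3, 0.0, 2.1]  # hypothetical pre-batchnorm activations
out_a = batchnorm_stepwise(x, gamma=1.5, beta=0.1)
out_b = batchnorm_fused(x, gamma=1.5, beta=0.1)
max_diff = max(abs(a - b) for a, b in zip(out_a, out_b))
print(max_diff)
```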
1:37:25
Now for the backward pass, we'd like to also implement a single formula for backpropagating through this entire operation, that is, the batch normalization. In the forward pass, we previously took hprebn, the hidden states before batch normalization, and created hpreact, which is the hidden states just before the activation. In the batch normalization paper, hprebn is x and hpreact is y.
1:37:51
So in the backward pass, what we'd like to do now is take dhpreact and produce dhprebn, and we'd like to do that in a very efficient manner. That's the name of the game: calculate dhprebn given dhpreact. For the purposes of this exercise, we're going to ignore gamma and beta and their derivatives, because they take on a very simple form, in a very similar way to what we did up above. So let's calculate this, given that, right here.
1:38:20
To help you a little bit, like I did before, I started off the implementation here on pen and paper, and I took two sheets of paper to derive the mathematical formulas for the backward pass. To set up the problem, just write out μ, σ² (the variance), x̂_i and y_i exactly as in the paper, except for the Bessel correction.
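For reference, the forward definitions for one feature over a mini-batch of m examples can be written as follows, in the paper's notation but with the 1/(m-1) Bessel correction in the variance instead of the paper's 1/m; ε is a small constant for numerical stability.

```latex
\begin{aligned}
\mu &= \frac{1}{m}\sum_{i=1}^{m} x_i \\
\sigma^2 &= \frac{1}{m-1}\sum_{i=1}^{m} \left(x_i - \mu\right)^2 \\
\hat{x}_i &= \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}} \\
y_i &= \gamma\,\hat{x}_i + \beta
\end{aligned}
```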
1:38:42
And then in the backward pass we have the derivative of the loss with respect to all the elements of y, and remember that y is a vector; there are multiple numbers here. So we have all the derivatives with respect to all the y's. Then there's a gamma and a beta, and this is roughly the compute graph: the gamma and the beta, the x̂, and then μ, σ² and x. So we have dL/dy_i, and we want dL/dx_i for all the i's in these vectors.
1:39:17
So this is the compute graph, and you have to be careful, because I'm trying to note here that these are vectors: there are many nodes inside x, x̂ and y, but μ and σ² are just individual scalars, single numbers. You have to imagine there are multiple nodes here, or you're going to get your math wrong.
1:39:40
As an example, I would suggest that you go in the following order, one, two, three, four, in terms of the backpropagation: backpropagate into x̂, then into σ², then into μ, and then into x. Just like in a topological sort in micrograd, we go from right to left; you're doing the exact same thing, except with symbols and on a piece of paper.
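The suggested order is just the reverse of a topological order of the compute graph. As a sketch in the same spirit as the topological sort in micrograd's backward pass (the node names here are hypothetical, and gamma and beta are left out as in the exercise), a depth-first post-order from y gives the forward order, and reversing it visits x̂, σ², μ and finally x after the output y.

```python
# Hypothetical node names for the batchnorm compute graph (one feature).
# Each node lists the nodes it is computed from; gamma and beta are left out.
inputs = {
    "x": [],
    "mu": ["x"],
    "sigma2": ["x", "mu"],
    "xhat": ["x", "mu", "sigma2"],
    "y": ["xhat"],
}

def topological_order(root, graph):
    # depth-first post-order: a node is appended only after all of its inputs
    order, visited = [], set()

    def visit(node):
        if node in visited:
            return
        visited.add(node)
        for parent in graph[node]:
            visit(parent)
        order.append(node)

    visit(root)
    return order

forward_order = topological_order("y", inputs)
backward_order = list(reversed(forward_order))
print(backward_order)
```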
1:40:05
So for number one, and I'm not giving away too much here: if you want dL/dx̂_i, you just take dL/dy_i and multiply it by gamma, because of this expression here, where any individual y_i is just gamma times x̂_i plus beta. That doesn't help you too much, but it gives you the derivatives for all the x̂'s. So now try to go through this computational graph and derive dL/dσ², then dL/dμ, and eventually dL/dx. Give it a go, and I'm going to reveal the answer one piece at a time. Okay, so to get dL/dσ², we have to remember again, as I mentioned, that there are many x̂'s here, and that σ² is just a single individual number.
1:40:59
So when we look at the expression for dL/dσ², we have to actually consider all the possible paths. There are many x̂'s and they all depend on σ², so σ² has a large fan-out: there are lots of arrows coming out of σ² into all the x̂'s. Then there's a backpropagating signal from each x̂ into σ², and that's why we need to sum over all the i's from i = 1 to m of dL/dx̂_i, which is the global gradient, times dx̂_i/dσ², which is the local gradient of this operation here.
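Written out, steps one and two look like this, treating σ² as a separate node so that x̂_i = (x_i - μ)(σ² + ε)^(-1/2); the sum over i is the fan-out from σ² into every x̂_i.

```latex
\begin{aligned}
\frac{\partial L}{\partial \hat{x}_i} &= \gamma\,\frac{\partial L}{\partial y_i} \\
\frac{\partial L}{\partial \sigma^2}
  &= \sum_{i=1}^{m} \frac{\partial L}{\partial \hat{x}_i}\,\frac{\partial \hat{x}_i}{\partial \sigma^2}
   = -\frac{1}{2}\left(\sigma^2 + \epsilon\right)^{-3/2} \sum_{i=1}^{m} \frac{\partial L}{\partial \hat{x}_i}\left(x_i - \mu\right)
\end{aligned}
```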
1:41:46
Then mathematically I'm just working it out here and simplifying, and you get a certain expression for dL/dσ², which we're going to use when we backpropagate into μ and then eventually into x. So now let's continue our backpropagation into μ. What is dL/dμ? Again, be careful: μ influences x̂, and x̂ is actually lots of values. For example, if our mini-batch size is 32, as it is in the example we were working on, then this is 32 numbers and 32 arrows going back to μ. And then μ going to σ² is just a single arrow, because σ² is a scalar. So in total there are 33 arrows emanating from μ, all of them have gradients coming back into μ, and they all need to be summed up.
1:42:31
That's why, when we look at the expression for dL/dμ, I am summing up over all the gradients dL/dx̂_i times dx̂_i/dμ (that's this arrow, and it's 32 arrows here), plus the one arrow from here, which is dL/dσ² times dσ²/dμ.
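The same bookkeeping for μ, with m paths through the x̂_i plus one path through σ², gives the following under the Bessel-corrected variance; the final sum in the second term is the factor that becomes zero when μ is the batch mean.

```latex
\begin{aligned}
\frac{\partial L}{\partial \mu}
  &= \sum_{i=1}^{m} \frac{\partial L}{\partial \hat{x}_i}\,\frac{\partial \hat{x}_i}{\partial \mu}
   + \frac{\partial L}{\partial \sigma^2}\,\frac{\partial \sigma^2}{\partial \mu} \\
  &= -\frac{1}{\sqrt{\sigma^2 + \epsilon}} \sum_{i=1}^{m} \frac{\partial L}{\partial \hat{x}_i}
   + \frac{\partial L}{\partial \sigma^2} \cdot \frac{-2}{m-1} \sum_{i=1}^{m} \left(x_i - \mu\right)
\end{aligned}
```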
1:42:52
So now we have to work out that expression; let me just reveal the rest of it. Simplifying the first term is not complicated, and you just get an expression here. For the second term, though, something really interesting happens when we look at dσ²/dμ and simplify. If we assume the special case where μ is actually the average of the x_i's, as it is here, and plug that in, then the gradient vanishes and becomes exactly zero, and that makes the entire second term cancel.
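A quick numerical check of why the second term cancels, as a sketch on a hypothetical mini-batch: treating σ² as a function of μ with x held fixed, dσ²/dμ = -2/(m-1) Σ(x_i - μ), which is zero (up to rounding) exactly when μ is the batch mean and nonzero for any other μ. With that term gone, dL/dμ reduces to -(σ² + ε)^(-1/2) Σ dL/dx̂_i.

```python
# hypothetical mini-batch for a single feature
x = [0.5, -1.2, 3.3, 0.0, 2.1, -0.7]
m = len(x)

def var_given_mu(mu):
    # Bessel-corrected sigma^2 as an explicit function of mu, with x held fixed
    return sum((xi - mu) ** 2 for xi in x) / (m - 1)

def dvar_dmu(mu):
    # analytic derivative: d sigma^2 / d mu = -2/(m-1) * sum(x_i - mu)
    return -2.0 / (m - 1) * sum(xi - mu for xi in x)

batch_mean = sum(x) / m
h = 1e-6

# at the batch mean the derivative vanishes (up to floating-point rounding)
analytic_at_mean = dvar_dmu(batch_mean)
numeric_at_mean = (var_given_mu(batch_mean + h) - var_given_mu(batch_mean - h)) / (2 * h)

# away from the mean it does not: shifting mu by 0.3 gives -2/(m-1) * (-m * 0.3)
analytic_elsewhere = dvar_dmu(batch_mean + 0.3)

print(analytic_at_mean, numeric_at_mean, analytic_elsewhere)
```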
1:43:29
So if you just have a mathematical expression like this and you look at dσ²/dμ, you get some formula for how μ impacts σ². But in the special case that μ is actually equal to the average, as it is in batch normalization, that gradient vanishes and becomes zero, so the whole term cancels and we just get a fairly straightforward expression here for dL/dμ. Okay, and now we get to the craziest part, which is deriving dL/dx_i, which is ultimately what we're after.
1:44:00
now let's count
1:44:00
now let's count first of all how many numbers are there
1:44:03
first of all how many numbers are there
1:44:03
first of all how many numbers are there inside X as I mentioned there are 32
1:44:05
inside X as I mentioned there are 32
1:44:05
inside X as I mentioned there are 32 numbers there are 32 Little X I's and
1:44:08
numbers there are 32 Little X I's and
1:44:08
numbers there are 32 Little X I's and let's count the number of arrows
1:44:09
let's count the number of arrows
1:44:09
let's count the number of arrows emanating from each x i
1:44:11
emanating from each x i
1:44:11
emanating from each x i there's an arrow going to Mu an arrow
1:44:13
there's an arrow going to Mu an arrow
1:44:13
there's an arrow going to Mu an arrow going to Sigma Square
1:44:14
going to Sigma Square
1:44:14
going to Sigma Square and then there's an arrow going to X hat
1:44:16
and then there's an arrow going to X hat
1:44:16
and then there's an arrow going to X hat but this Arrow here let's scrutinize
1:44:19
but this Arrow here let's scrutinize
1:44:19
but this Arrow here let's scrutinize that a little bit
1:44:20
that a little bit
1:44:20
that a little bit each x i hat is just a function of x i
1:44:23
each x i hat is just a function of x i
1:44:23
each x i hat is just a function of x i and all the other scalars so x i hat
1:44:27
and all the other scalars so x i hat
1:44:27
and all the other scalars so x i hat only depends on x i and none of the
1:44:29
only depends on x i and none of the
1:44:29
only depends on x i and none of the other X's
1:44:30
other X's
1:44:30
other X's and so therefore there are actually in
1:44:32
and so therefore there are actually in
1:44:32
and so therefore there are actually in this single Arrow there are 32 arrows
1:44:34
this single Arrow there are 32 arrows
1:44:34
this single Arrow there are 32 arrows but those 32 arrows are going exactly
1:44:37
but those 32 arrows are going exactly
1:44:37
but those 32 arrows are going exactly parallel they don't interfere and
1:44:38
parallel they don't interfere and
1:44:39
parallel they don't interfere and they're just going parallel between x
1:44:40
they're just going parallel between x
1:44:40
they're just going parallel between x and x hat you can look at it that way
1:44:42
and x hat you can look at it that way
1:44:42
and x hat you can look at it that way and so how many arrows are emanating
1:44:44
and so how many arrows are emanating
1:44:44
and so how many arrows are emanating from each x i there are three arrows mu
1:44:47
from each x i there are three arrows mu
1:44:47
from each x i there are three arrows mu Sigma squared and the associated X hat
1:44:50
Sigma squared and the associated X hat
1:44:50
Sigma squared and the associated X hat and so in back propagation we now need
1:44:53
and so in back propagation we now need
1:44:53
and so in back propagation we now need to apply the chain rule and we need to
1:44:55
to apply the chain rule and we need to
1:44:55
to apply the chain rule and we need to add up those three contributions
1:44:59
So here's what that looks like if I just write it out. We're chaining through μ, through σ², and through x hat, and those three terms are just here.
1:45:13
Now, we already have three of these. We have dL/dx_i hat, we have dL/dμ, which we derived here, and we have dL/dσ², which we derived here. But we need three other terms: the local derivatives of x_i hat, μ and σ² with respect to x_i, this one, this one and this one. I invite you to try to derive them. It's not that complicated: you're just looking at these expressions here and differentiating with respect to x_i. So give it a shot, but here's the result, or at least what I got. I'm just differentiating all these expressions with respect to x_i, and honestly I don't think there's anything too tricky here; it's basic calculus.
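For reference, the chain-rule sum and the three local derivatives can be written out as follows. This assumes the definitions used in this derivation:

- a batch of m values x_1, ..., x_m for one neuron;
- μ = (1/m)·Σ x_j;
- the Bessel-corrected σ² = 1/(m-1)·Σ(x_j - μ)²;
- x̂_i = (x_i - μ)/√(σ² + ε);
- y_i = γ·x̂_i + β.

Sums use j as the running index so that it is not confused with the fixed i.

```latex
\frac{dL}{dx_i}
  = \frac{\partial L}{\partial \hat{x}_i}\,\frac{\partial \hat{x}_i}{\partial x_i}
  + \frac{\partial L}{\partial \mu}\,\frac{\partial \mu}{\partial x_i}
  + \frac{\partial L}{\partial \sigma^2}\,\frac{\partial \sigma^2}{\partial x_i}

\frac{\partial \hat{x}_i}{\partial x_i} = (\sigma^2 + \epsilon)^{-1/2}, \qquad
\frac{\partial \mu}{\partial x_i} = \frac{1}{m}, \qquad
\frac{\partial \sigma^2}{\partial x_i} = \frac{2}{m-1}\,(x_i - \mu)

\frac{\partial L}{\partial \hat{x}_i} = \gamma\,\frac{\partial L}{\partial y_i}, \qquad
\frac{\partial L}{\partial \sigma^2} = -\frac{1}{2}\,(\sigma^2 + \epsilon)^{-3/2} \sum_{j=1}^{m} \frac{\partial L}{\partial \hat{x}_j}\,(x_j - \mu), \qquad
\frac{\partial L}{\partial \mu} = -(\sigma^2 + \epsilon)^{-1/2} \sum_{j=1}^{m} \frac{\partial L}{\partial \hat{x}_j}
```

The last expression already drops the σ² path into μ, because that derivative vanishes when μ is the batch mean.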
1:45:52
Now, where it gets a little bit more tricky is that we are going to plug everything together. All of these terms get multiplied with all of these terms and added up according to this formula, and that gets a little bit hairy. What ends up happening is that you get a large expression. The thing to be very careful with here, of course, is that we are working with dL/dx_i for a specific i. But when we plug in some of these terms, say this term here, dL/dσ², you see that for dL/dσ² I end up with an expression where I'm iterating over little i's. I can't use i as the variable when I plug it in here, because that is a different i from this i: the i there is just a placeholder, like a local variable for a for loop. So when I plug that in, you'll notice that I rename the i to a j. I need to make sure that this j is not this i; this j is a little local iterator over the 32 terms. So you have to be careful when you're plugging in the expressions from here to here. You may have to rename i's into j's, and you have to be very careful about what is actually the i with respect to dL/dx_i. So some of these are j's and some of these are i's.
1:47:11
And then we simplify this expression. The big thing to notice here is that a bunch of terms come out to the front, and you can refactor them. There's a (σ² + ε) raised to the power of -3/2, and this can actually be separated into three factors, each of which is (σ² + ε)^(-1/2). The three of them multiplied together equal (σ² + ε)^(-3/2). Because it's a multiplication, those three factors can then go to different places. One of them comes out to the front and ends up outside, one of them joins up with this term, and one of them joins up with this other term. And when you simplify the expression, you'll notice that some of the terms that come out are just the x_i hats, so you can simplify just by rewriting them that way.
1:47:58
What we end up with is a fairly simple mathematical expression over here that I cannot simplify further. You'll notice that it only uses the stuff we have, and it derives the thing we need. We have dL/dy_i for all the i's, and those are used plenty of times here. In addition, we're using the x_i hats and x_j hats, and they just come from the forward pass. Otherwise this is a simple expression, and it gives us dL/dx_i for all the i's, which is ultimately what we're interested in.
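Under the same assumptions as the derivation (Bessel-corrected variance, y_i = γ·x̂_i + β), the simplified result can be written as follows. One of the three (σ² + ε)^(-1/2) factors sits in front. The other two combine with (x_i - μ) and (x_j - μ) to form x̂_i and x̂_j. With a biased 1/m variance, the factor m/(m-1) would become 1.

```latex
\frac{dL}{dx_i} = \frac{\gamma\,(\sigma^2 + \epsilon)^{-1/2}}{m}
  \left[\, m\,\frac{\partial L}{\partial y_i}
  - \sum_{j=1}^{m} \frac{\partial L}{\partial y_j}
  - \frac{m}{m-1}\,\hat{x}_i \sum_{j=1}^{m} \frac{\partial L}{\partial y_j}\,\hat{x}_j \right]
```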
1:48:32
So that's the end of the batchnorm backward pass, analytically. Let's now implement this final result.
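Below is a minimal pure-Python sketch of that final formula for a single neuron. It is checked against central finite differences rather than PyTorch, so it runs with the standard library alone. It assumes the Bessel-corrected variance and y_i = γ·x̂_i + β. The toy loss L = Σ w_i·y_i (so that dL/dy_i = w_i) and all function names are illustrative; this is not the lecture notebook's code.

```python
def batchnorm_forward(x, gamma, beta, eps=1e-5):
    """Forward pass for one neuron over a batch x (a list of floats)."""
    m = len(x)
    mu = sum(x) / m
    var = sum((xi - mu) ** 2 for xi in x) / (m - 1)   # Bessel-corrected variance
    inv_std = (var + eps) ** -0.5                      # (sigma^2 + eps)^(-1/2)
    xhat = [(xi - mu) * inv_std for xi in x]
    y = [gamma * xh + beta for xh in xhat]
    return y, xhat, inv_std


def batchnorm_backward(dy, xhat, inv_std, gamma):
    """dL/dx_i = gamma*inv_std/m * (m*dy_i - sum_j dy_j - m/(m-1) * xhat_i * sum_j dy_j*xhat_j)."""
    m = len(dy)
    sum_dy = sum(dy)
    sum_dy_xhat = sum(d * xh for d, xh in zip(dy, xhat))
    return [gamma * inv_std / m * (m * d - sum_dy - m / (m - 1) * xh * sum_dy_xhat)
            for d, xh in zip(dy, xhat)]


def loss(x, gamma, beta, w):
    """Toy scalar loss L = sum_i w_i * y_i, so that dL/dy_i = w_i."""
    y, _, _ = batchnorm_forward(x, gamma, beta)
    return sum(wi * yi for wi, yi in zip(w, y))


x = [0.5, -1.2, 3.3, 0.0, 2.1]
w = [0.3, -0.7, 1.1, 0.4, -0.2]     # arbitrary upstream gradient dL/dy
gamma, beta = 1.5, 0.2

_, xhat, inv_std = batchnorm_forward(x, gamma, beta)
dx = batchnorm_backward(w, xhat, inv_std, gamma)

# Central finite differences as an independent reference.
h = 1e-6
dx_num = []
for i in range(len(x)):
    xp, xm = x[:], x[:]
    xp[i] += h
    xm[i] -= h
    dx_num.append((loss(xp, gamma, beta, w) - loss(xm, gamma, beta, w)) / (2 * h))

max_diff = max(abs(a - b) for a, b in zip(dx, dx_num))
print(f"max diff vs finite differences: {max_diff:.2e}")
```

The returned gradients also sum to (numerically) zero. That fits the fact that shifting every x_i by the same constant leaves the batch-normalized output unchanged.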
1:48:38
Okay, so I implemented the expression as a single line of code here, and you can see that the max diff is tiny, so this is a correct implementation of this formula. Now, I'll just tell you that getting this formula from this mathematical expression was not trivial. There's a lot packed into this one formula, and it's a whole exercise by itself. You have to consider the fact that this formula is just for a single neuron and a batch of 32 examples, but here we actually have 64 neurons. So this expression has to evaluate the batchnorm backward pass for all 64 neurons, in parallel and independently, which means it has to happen in every single column of the inputs here. In addition, you see how there are a bunch of sums here. We need to make sure that when I do those sums, they broadcast correctly onto everything else that's here.
1:49:35
So getting this expression is highly non-trivial, and I invite you to look through it and step through it; it's a whole exercise to make sure that this checks out. But once all the shapes agree and you've convinced yourself that it's correct, you can also verify that PyTorch gets essentially the same answer. That gives you a lot of peace of mind that this mathematical formula is correctly implemented, broadcast correctly, and replicated in parallel for all 64 neurons inside this batchnorm layer.
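To illustrate the per-column bookkeeping, here is a pure-Python sketch in which rows are examples and columns are neurons. The helper names are hypothetical. Each column gets its own two sums, computed once and then reused for every row of that column. In PyTorch the analogous reduction is `tensor.sum(0)` (or `sum(0, keepdim=True)`), whose result broadcasts back against the full batch-by-neurons tensor. The toy loss and numbers are made up, and the check uses finite differences instead of PyTorch.

```python
def col_sums(a):
    """Sum over the batch (row) dimension: one value per column/neuron."""
    return [sum(col) for col in zip(*a)]


def bn_forward(x, gamma, beta, eps=1e-5):
    """x: m x d batch (rows = examples, columns = neurons); gamma, beta: length-d lists."""
    m = len(x)
    mu = [s / m for s in col_sums(x)]
    diff = [[v - mu[k] for k, v in enumerate(row)] for row in x]
    var = [s / (m - 1) for s in col_sums([[d * d for d in row] for row in diff])]
    inv_std = [(v + eps) ** -0.5 for v in var]
    xhat = [[d * inv_std[k] for k, d in enumerate(row)] for row in diff]
    y = [[gamma[k] * xh + beta[k] for k, xh in enumerate(row)] for row in xhat]
    return y, xhat, inv_std


def bn_backward(dy, xhat, inv_std, gamma):
    """Batchnorm backward for every column at once. The two per-column sums are
    computed once (one value per neuron) and reused for every row of that
    column, which is what broadcasting does in a vectorized version."""
    m = len(dy)
    sum_dy = col_sums(dy)
    sum_dy_xhat = col_sums([[d * xh for d, xh in zip(dr, xr)] for dr, xr in zip(dy, xhat)])
    return [[gamma[k] * inv_std[k] / m
             * (m * dr[k] - sum_dy[k] - m / (m - 1) * xr[k] * sum_dy_xhat[k])
             for k in range(len(dr))]
            for dr, xr in zip(dy, xhat)]


x = [[0.2, -1.0, 2.0], [1.5, 0.3, -0.7], [-0.4, 2.2, 0.9], [0.8, -0.6, 1.1]]
gamma, beta = [1.0, 0.5, 2.0], [0.0, 0.1, -0.3]
w = [[0.3, -0.2, 0.5], [0.1, 0.4, -0.6], [-0.7, 0.2, 0.3], [0.2, -0.5, 0.1]]  # arbitrary dL/dy


def total_loss(x):
    """Toy loss L = sum_{i,k} w_ik * y_ik, so dL/dy = w."""
    y, _, _ = bn_forward(x, gamma, beta)
    return sum(wv * yv for wr, yr in zip(w, y) for wv, yv in zip(wr, yr))


_, xhat, inv_std = bn_forward(x, gamma, beta)
dx = bn_backward(w, xhat, inv_std, gamma)

# Compare every entry against central finite differences.
h = 1e-6
max_diff = 0.0
for i in range(len(x)):
    for k in range(len(x[0])):
        xp = [row[:] for row in x]
        xm = [row[:] for row in x]
        xp[i][k] += h
        xm[i][k] -= h
        numeric = (total_loss(xp) - total_loss(xm)) / (2 * h)
        max_diff = max(max_diff, abs(numeric - dx[i][k]))
print(f"max diff vs finite differences: {max_diff:.2e}")
```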
1:50:05
Finally, exercise number four asks you to put it all together. Here we have a redefinition of the entire problem: you see that we reinitialize the neural net from scratch. Then, instead of calling loss.backward(), we want to have the manual backpropagation as we derived it up above. So go up, copy and paste all the chunks of code that we've already derived, put them here, and derive your own gradients. Then optimize this neural net using your own gradients, all the way to the calibration of the batchnorm and the evaluation of the loss.
1:50:36
And I was able to achieve quite a good loss, basically the same loss you would achieve before. That shouldn't be surprising, because all we've done is gone into loss.backward(), pulled out all the code, and inserted it here. The gradients are the same up to tiny numerical differences, everything else is identical, and the results are identical. It's just that we have full visibility into exactly what goes on under the hood of loss.backward() in this specific case.
1:51:02
And this is all of our code. This is the full backward pass, using the simplified backward passes for the cross-entropy loss and for batch normalization. It backpropagates through cross-entropy, the second layer, the tanh nonlinearity, the batch normalization, the first layer, and the embedding. You see that this is only maybe 20 lines of code or something like that, and that's what gives us gradients, so now we can potentially erase loss.backward(). The way I have the code set up, you should be able to run this entire cell once you fill this in, and it will run for only 100 iterations and then break. It breaks because that gives you an opportunity to check your gradients against PyTorch.
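Here is a small, self-contained pure-Python sketch of the simplified cross-entropy backward pass mentioned above, and of checking a hand-written gradient against an independent reference. It assumes the loss is the batch mean of -log softmax(logits)[target], whose gradient is (softmax(logits) - one_hot(target)) / n. The reference here is central finite differences rather than PyTorch, and all names are illustrative.

```python
import math


def softmax(row):
    mx = max(row)                              # subtract the max for numerical stability
    exps = [math.exp(v - mx) for v in row]
    z = sum(exps)
    return [e / z for e in exps]


def cross_entropy(logits, targets):
    """Mean negative log-likelihood of the target class over the batch."""
    return -sum(math.log(softmax(row)[t]) for row, t in zip(logits, targets)) / len(logits)


def cross_entropy_backward(logits, targets):
    """Simplified gradient: dL/dlogits = (softmax(logits) - one_hot(targets)) / n."""
    n = len(logits)
    grad = []
    for row, t in zip(logits, targets):
        g = [p / n for p in softmax(row)]
        g[t] -= 1.0 / n
        grad.append(g)
    return grad


logits = [[1.0, -0.5, 0.3, 2.0], [0.1, 0.2, -1.3, 0.7], [-0.8, 1.5, 0.0, 0.4]]
targets = [3, 0, 1]
dlogits = cross_entropy_backward(logits, targets)

# Check every entry against central finite differences.
h = 1e-6
max_diff = 0.0
for i, row in enumerate(logits):
    for j in range(len(row)):
        lp = [r[:] for r in logits]
        lm = [r[:] for r in logits]
        lp[i][j] += h
        lm[i][j] -= h
        numeric = (cross_entropy(lp, targets) - cross_entropy(lm, targets)) / (2 * h)
        max_diff = max(max_diff, abs(numeric - dlogits[i][j]))
print(f"max diff vs finite differences: {max_diff:.2e}")
```

Each row of the gradient sums to zero. Raising one logit's probability has to come out of the others, since the probabilities always sum to one.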
1:51:46
So here we see that our gradients are not exactly equal. They are approximately equal, and the differences are tiny, around 1e-9 or so. To be honest, I don't exactly know where they're coming from.
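Here is a sketch of an exact-versus-approximate comparison in plain Python, using a hypothetical `compare` helper; PyTorch offers `torch.allclose` for the approximate part. The demo only shows a general fact: mathematically equal expressions evaluated in a different order can differ in the last bits of a float. It is not offered as the explanation for the specific 1e-9 differences above.

```python
import math


def compare(name, ours, reference, rel_tol=1e-5, abs_tol=1e-8):
    """Report exact equality, approximate equality, and the max absolute difference
    between two equally sized lists of floats (a hypothetical helper, not a library API)."""
    exact = all(a == b for a, b in zip(ours, reference))
    approx = all(math.isclose(a, b, rel_tol=rel_tol, abs_tol=abs_tol)
                 for a, b in zip(ours, reference))
    maxdiff = max(abs(a - b) for a, b in zip(ours, reference))
    print(f"{name:10s} | exact: {str(exact):5s} | approximate: {str(approx):5s} | maxdiff: {maxdiff:.1e}")
    return exact, approx, maxdiff


# The same sum grouped two different ways: equal on paper, not bit-for-bit in floats.
compare("demo", [(0.1 + 0.2) + 0.3], [0.1 + (0.2 + 0.3)])
```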
1:51:57
So once we have some confidence that the gradients are basically correct, we can take out the gradient tracking and disable this break statement. Then we can disable loss.backward(); we don't need it anymore, and it feels amazing to say that. And here, when we are doing the update, we're not going to use p.grad. That is the old PyTorch way, and we don't have that anymore because we're not calling backward. Instead we are going to use this update. I've arranged the grads to be in the same order as the parameters, and I'm zipping the gradients and the parameters together into p and grad. Then here I'm going to step with just the grad that we derived manually.
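The update described above can be sketched in plain Python on a toy problem. The sketch fits y = w·x + b by mean squared error with hand-derived gradients. The gradients are kept in a list in the same order as the parameters and paired up with `zip`. Python floats are immutable, so the sketch builds a new parameter list each step; with PyTorch tensors the same loop would instead update each parameter tensor in place. The data, learning rate and function name are made up.

```python
def forward_and_grads(params, xs, ys):
    """Mean squared error for y_hat = w*x + b, with hand-derived gradients
    returned in the same order as params ([w, b])."""
    w, b = params
    n = len(xs)
    errs = [w * x + b - y for x, y in zip(xs, ys)]
    loss = sum(e * e for e in errs) / n
    dw = sum(2 * e * x for e, x in zip(errs, xs)) / n
    db = sum(2 * e for e in errs) / n
    return loss, [dw, db]


xs = [0.0, 1.0, 2.0, 3.0]
ys = [1.0, 3.0, 5.0, 7.0]   # generated by y = 2x + 1
params = [0.0, 0.0]         # [w, b]
lr = 0.05

for step in range(2000):
    loss, grads = forward_and_grads(params, xs, ys)
    # Manual update: pair each parameter with its own gradient; no autograd involved.
    params = [p - lr * g for p, g in zip(params, grads)]

print(params, loss)
```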
1:52:43
So the last piece is that none of this now requires gradients from PyTorch. So one thing you can do here is wrap this whole code block in with torch.no_grad() and indent it underneath. Really, what you're saying is that you're telling PyTorch, "Hey, I'm not going to call backward on any of this," and this allows PyTorch to be a bit more efficient with all of it.
1:53:07
And then we should be able to just run this, and it's running. You see that loss.backward() is commented out, and we're optimizing. So we're going to leave this running and hopefully we get a good result.
1:53:27
Okay, so I allowed the neural net to finish optimization. Then here I calibrate the batchnorm parameters, because I did not keep track of the running mean and variance in the training loop.
1:53:37
their training Loop
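Here is one way to do this calibration, sketched in plain Python. It assumes the pre-activations for the whole training set have already been computed as rows of floats. The per-feature mean and variance are taken once over everything, and those fixed statistics are then used at inference. The helper names `calibrate_batchnorm` and `batchnorm_inference` are hypothetical, and so is the choice of the unbiased (n-1) variance.

```python
import math

def calibrate_batchnorm(hpreact_rows):
    """Per-feature mean and unbiased (n-1) variance over all training rows."""
    n = len(hpreact_rows)
    d = len(hpreact_rows[0])
    mean = [sum(row[j] for row in hpreact_rows) / n for j in range(d)]
    var = [sum((row[j] - mean[j]) ** 2 for row in hpreact_rows) / (n - 1)
           for j in range(d)]
    return mean, var

def batchnorm_inference(x, mean, var, gain, bias, eps=1e-5):
    """Normalize one example using the fixed, calibrated statistics."""
    return [gain[j] * (x[j] - mean[j]) / math.sqrt(var[j] + eps) + bias[j]
            for j in range(len(x))]

# Hypothetical pre-activations: a 3-example "training set" with 2 features.
rows = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
bnmean, bnvar = calibrate_batchnorm(rows)
print(bnmean, bnvar)  # [3.0, 4.0] [4.0, 4.0]
# Both entries come out just under 1.0 because of eps.
print(batchnorm_inference([5.0, 6.0], bnmean, bnvar, [1.0, 1.0], [0.0, 0.0]))
```

The alternative is to keep exponential running averages during training. This is what `torch.nn.BatchNorm1d` does with its `running_mean` and `running_var` buffers.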
1:53:37
Then here I ran the loss, and you see that we actually obtained a pretty good loss, very similar to what we've achieved before.
1:53:43
And then here I'm sampling from the model, and we see some of the name-like gibberish that we're sort of used to. So basically the model worked and samples pretty decent results compared to what we were used to. So everything is the same, but of course the big deal is that we did not use `loss.backward()`, we did not use PyTorch autograd, and we estimated our gradients ourselves by hand.
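Here is a small example of estimating a gradient by hand instead of calling `loss.backward()`, sketched in plain Python for the cross-entropy loss at the output of the net. For a batch of n examples, the gradient with respect to the logits is softmax(logits) minus the one-hot target, divided by n. The function names and toy numbers are hypothetical. The central finite difference is there only to check the analytic result.

```python
import math

def cross_entropy(logits, targets):
    """Mean negative log-likelihood over a batch of logit rows."""
    total = 0.0
    for row, y in zip(logits, targets):
        m = max(row)  # subtract the max for numerical stability
        logsumexp = m + math.log(sum(math.exp(v - m) for v in row))
        total += logsumexp - row[y]
    return total / len(logits)

def dlogits_by_hand(logits, targets):
    """Analytic gradient: (softmax(row) - onehot(y)) / n for every row."""
    n = len(logits)
    grads = []
    for row, y in zip(logits, targets):
        m = max(row)
        exps = [math.exp(v - m) for v in row]
        s = sum(exps)
        g = [e / s / n for e in exps]
        g[y] -= 1.0 / n
        grads.append(g)
    return grads

def dlogits_numerical(logits, targets, h=1e-5):
    """Central finite differences, used only as a check."""
    grads = []
    for i in range(len(logits)):
        g = []
        for j in range(len(logits[i])):
            plus = [r[:] for r in logits]
            minus = [r[:] for r in logits]
            plus[i][j] += h
            minus[i][j] -= h
            diff = cross_entropy(plus, targets) - cross_entropy(minus, targets)
            g.append(diff / (2 * h))
        grads.append(g)
    return grads

logits = [[2.0, -1.0, 0.5], [0.1, 0.2, 0.3]]
targets = [0, 2]
analytic = dlogits_by_hand(logits, targets)
numeric = dlogits_numerical(logits, targets)
max_diff = max(abs(a - b) for ra, rn in zip(analytic, numeric)
               for a, b in zip(ra, rn))
print(max_diff < 1e-8)  # True
```

Each row of the analytic gradient sums to zero. The softmax part sums to 1/n and exactly 1/n is subtracted at the target, so the net pushes the correct logit up and all the others down.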
1:54:05
And so hopefully you're looking at the backward pass of this neural net and you're thinking to yourself, actually, that's not too complicated. Each one of these layers is like three lines of code or something like that, and most of it is fairly straightforward.
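Here is the "about three lines per layer" point made concrete, as a plain-Python sketch of the backward pass for a linear layer h = x @ W + b. It assumes the upstream gradient dh is given. Then dx = dh @ W^T, dW = x^T @ dh, and db is the column sum of dh. The small `matmul` and `transpose` helpers are hypothetical and are written out only so the example has no dependencies.

```python
def matmul(a, b):
    """Plain-Python matrix product of a (n x k) and b (k x m)."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))]
            for i in range(len(a))]

def transpose(a):
    return [list(col) for col in zip(*a)]

def linear_backward(x, W, dh):
    """Backward pass of h = x @ W + b, given the upstream gradient dh."""
    dx = matmul(dh, transpose(W))        # gradient flowing to the layer's input
    dW = matmul(transpose(x), dh)        # gradient for the weights
    db = [sum(col) for col in zip(*dh)]  # b was broadcast over rows, so sum them
    return dx, dW, db

x = [[1.0, 2.0], [3.0, 4.0]]               # 2 examples, 2 input features
W = [[0.5, -1.0, 2.0], [1.5, 0.0, -0.5]]   # 2 inputs -> 3 outputs
dh = [[1.0, 0.0, 2.0], [0.0, 1.0, -1.0]]   # hypothetical upstream gradient
dx, dW, db = linear_backward(x, W, dh)
print(db)  # [1.0, 1.0, 1.0]
```

The three gradient lines inside `linear_backward` are the whole layer. The only subtlety is the bias: it was broadcast across the batch in the forward pass, so its gradient sums over the batch.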
1:54:18
The notable exception is potentially the batch normalization backward pass; otherwise it's pretty good.
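The batch normalization backward pass can still be collapsed into a single expression per element. The sketch below works on one feature column in plain Python and assumes train-time batchnorm with a Bessel-corrected (n-1) variance. With xhat = (x - mean) / sigma, sigma = sqrt(var + eps), and y = gain * xhat + bias, the input gradient is dx_i = gain / (n * sigma) * (n * dy_i - sum(dy) - n / (n-1) * xhat_i * sum(dy * xhat)). The function names and numbers are hypothetical. The finite-difference check is there only to confirm the formula.

```python
import math

def batchnorm_forward(x, gain, bias, eps=1e-5):
    """Train-time batchnorm for one feature column x (a list of n floats)."""
    n = len(x)
    mean = sum(x) / n
    var = sum((v - mean) ** 2 for v in x) / (n - 1)  # Bessel-corrected variance
    sigma = math.sqrt(var + eps)
    xhat = [(v - mean) / sigma for v in x]
    return [gain * xh + bias for xh in xhat], xhat, sigma

def batchnorm_backward(dy, xhat, sigma, gain):
    """Gradient with respect to the column x, given the upstream gradient dy."""
    n = len(dy)
    sum_dy = sum(dy)
    sum_dy_xhat = sum(d * xh for d, xh in zip(dy, xhat))
    return [gain / (n * sigma) * (n * d - sum_dy - n / (n - 1) * xh * sum_dy_xhat)
            for d, xh in zip(dy, xhat)]

# Hypothetical column, parameters and upstream gradient.
x = [1.0, 2.0, 4.0, 7.0]
gain, bias = 1.5, 0.3
dy = [0.1, -0.2, 0.3, 0.5]  # corresponds to loss = sum(dy_i * y_i)

_, xhat, sigma = batchnorm_forward(x, gain, bias)
analytic = batchnorm_backward(dy, xhat, sigma, gain)

def loss(col):
    y, _, _ = batchnorm_forward(col, gain, bias)
    return sum(d * v for d, v in zip(dy, y))

step = 1e-5
numeric = []
for i in range(len(x)):
    plus, minus = x[:], x[:]
    plus[i] += step
    minus[i] -= step
    numeric.append((loss(plus) - loss(minus)) / (2 * step))

print(max(abs(a - b) for a, b in zip(analytic, numeric)) < 1e-6)  # True
```

The hard part is that every output depends on every input in the batch through the shared mean and variance. That coupling is why the two batch-wide sums appear in the formula.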
1:54:25
Okay, and that's everything I wanted to cover for this lecture. Hopefully you found this interesting. What I honestly liked about it is that it gave us a very nice diversity of layers to backpropagate through, and I think it gives a pretty nice and comprehensive sense of how these backward passes are implemented and how they work, and you'd be able to derive them yourself. Of course, in practice you probably don't want to, and you want to use PyTorch autograd. But hopefully you have some intuition about how gradients flow backwards through the neural net, starting at the loss, and how they flow through all the variables and all the intermediate results. And if you understood a good chunk of it and you have a sense of that, then you can count yourself as one of these buff doges on the left instead of those on the right here.
1:55:06
Now, in the next lecture we're actually going to go to recurrent neural nets, LSTMs, and all the other variants of RNNs, and we're going to start to complexify the architecture and start to achieve better log likelihoods. So I'm really looking forward to that, and I'll see you then.