00:00:00.000 Hi everyone. So today we are once again continuing our implementation of makemore.
00:00:04.800 Now so far we've come up to here, multilayer perceptrons, and our neural net looked like
00:00:10.240 this, and we were implementing this over the last few lectures. Now I'm sure everyone is
00:00:14.360 very excited to go into recurrent neural networks and all of their variants and how they work,
00:00:18.880 and the diagrams look cool and it's very exciting and interesting and we're going to get a better
00:00:21.800 result. But unfortunately I think we have to remain here for one more lecture. And the
00:00:27.440 reason for that is we've already trained this multilayer perceptron, right, and we are getting
00:00:31.680 pretty good loss. And I think we have a pretty decent understanding of the architecture and how it
00:00:35.480 works. But the line of code that I take issue with is this one: loss.backward(). That is,
00:00:41.760 we are taking PyTorch autograd and using it to calculate all of our gradients along the way.
00:00:47.240 And I would like to remove the use of loss.backward(), and I would like us to write our backward
00:00:51.860 pass manually on the level of tensors. And I think that this is a very useful exercise for
00:00:57.160 the following reasons. I actually have an entire blog post on this topic, but I like to call
00:01:02.280 backpropagation a leaky abstraction. And what I mean by that is backpropagation doesn't just make
00:01:08.840 your neural networks just work magically. It's not the case that you can just stack up arbitrary
00:01:13.040 LEGO blocks of differentiable functions and just cross your fingers and back propagate and
00:01:17.160 everything is great. Things don't just work automatically. It is a leaky abstraction in the
00:01:22.160 sense that you can shoot yourself in the foot if you do not understand its internals. It will
00:01:27.080 magically not work or not work optimally. And you will need to understand how it works under
00:01:32.560 the hood if you're hoping to debug it and if you are hoping to address it in your neural net.
00:01:36.440 So this blog post here from a while ago goes into some of those examples. We've already covered
00:01:42.640 some of them. For example, the flat tails of these functions and
00:01:48.200 how you do not want to saturate them too much because your gradients will die. The case of
00:01:53.680 dead neurons, which I've already covered as well. The case of exploding or vanishing gradients in
00:01:59.480 the case of recurrent neural networks, which we are about to cover. And then also you will often come across
00:02:05.080 some examples in the wild. This is a snippet that I found in a random code base on the internet,
00:02:10.920 where they actually have like a very subtle but pretty major bug in their implementation. And
00:02:16.040 the bug points at the fact that the author of this code does not actually understand back
00:02:20.760 propagation. What they're doing here is clipping the loss at a certain
00:02:25.160 maximum value. But what they actually want is to clip the gradients to
00:02:29.400 have a maximum value, instead of clipping the loss at a maximum value. And indirectly,
00:02:35.080 they're basically causing some of the outliers to be actually ignored. Because when you clip a
00:02:41.160 loss of an outlier, you are setting its gradient to zero. And so have a look through this and
00:02:47.000 read through it. But there's basically a bunch of subtle issues that you're going to avoid if you
00:02:51.080 actually know what you're doing. And that's why I don't think it's the case that because PyTorch or
00:02:55.880 other frameworks offer autograd, it is okay for us to ignore how it works. Now we've actually already
00:03:02.040 covered autograd and we wrote micrograd. But micrograd was an autograd engine only on the level of
00:03:08.040 individual scalars. So the atoms were single individual numbers. And you know, I don't think
00:03:13.400 it's enough. And I'd like us to basically think about backpropagation on the level of tensors as well.
00:03:17.720 And so in a summary, I think it's a good exercise. I think it is very, very valuable. You're going to
00:03:22.920 become better at debugging neural networks and making sure that you understand what you're doing.
00:03:27.960 It is going to make everything fully explicit. So you're not going to be nervous about what is
00:03:31.400 hidden away from you. And basically in general, we're going to emerge stronger. And so let's get into
00:03:37.080 it. A bit of a fun historical note here is that today writing your backward pass by hand and
00:03:42.200 manually is not recommended. And no one does it except for the purposes of exercise. But about
00:03:47.400 10 years ago in deep learning, this was fairly standard and in fact pervasive. So at the time,
00:03:52.120 everyone used to write their backward pass by hand manually, including myself. And it's just what
00:03:56.760 you would do. So we used to write the backward pass by hand. And now everyone just calls loss.backward().
00:04:02.440 We've lost something. I wanted to give you a few examples of this. So here's a 2006 paper from
00:04:09.240 Geoff Hinton and Ruslan Salakhutdinov in Science that was influential at the time. And this was
00:04:15.160 training some architectures called restricted Boltzmann machines. And basically it's an autoencoder
00:04:20.760 trained here. And this is from roughly 2010. I had a library for training restricted Boltzmann
00:04:27.480 machines. And this was at the time written in MATLAB. So Python was not used for deep learning
00:04:32.840 pervasively. It was all MATLAB. And MATLAB was this scientific computing package that everyone
00:04:38.920 would use. So we would write MATLAB, which is barely a programming language as well. But it had a
00:04:44.760 very convenient tensor class. And it was this computing environment where you would run everything;
00:04:49.160 it would all run on the CPU, of course, but you would have very nice plots to go with it and a
00:04:53.560 built-in debugger. And it was pretty nice. Now the code in this package in 2010 that I wrote
00:04:59.080 for fitting restricted Boltzmann machines is, to a large extent, recognizable. But I wanted to
00:05:04.760 show you how you would, well: I'm creating the data in the X, Y batches. I'm initializing the neural
00:05:10.040 net. So it's got weights and biases just like we're used to. And then this is the training loop,
00:05:15.000 where we actually do the forward pass. And then here, at this time, they didn't even necessarily
00:05:20.360 use back propagation to train neural networks. So this in particular implements contrastive
00:05:25.000 divergence, which estimates a gradient. And then here we take that gradient and use it for a
00:05:31.080 parameter update along lines that we're used to. Yeah, here. But you can see that basically people
00:05:37.800 are meddling with these gradients directly and inline and themselves. It wasn't that common to
00:05:42.440 use an autograd engine. Here's one more example from a paper of mine from 2014, called Deep Fragment
00:05:48.680 Embeddings. And here what I was doing is I was aligning images and text. And so it's kind of like
00:05:54.440 CLIP, if you're familiar with it. But instead of working on the level of entire images and entire
00:05:58.680 sentences, it was working on the level of individual objects and little pieces of sentences. And I
00:06:03.320 was embedding them and then calculating a loss very much like a CLIP-style loss. And I dug up the code
00:06:08.120 from 2014 of how I implemented this. And it was already in NumPy and Python. And here I'm
00:06:15.080 implementing the cost function. And it was standard to implement not just the cost, but also the
00:06:20.120 backward pass manually. So here I'm calculating the image embeddings, sentence embeddings, the loss
00:06:26.040 function: I calculate the scores. This is the loss function. And then once I have the loss function,
00:06:31.960 I do the backward pass right here. So I backward through the loss function and through the neural
00:06:36.760 net and I append regularization. So everything was done by hand manually. And you would just write
00:06:42.600 out the backward pass. And then you would use a gradient checker to make sure that your numerical
00:06:46.920 estimate of the gradient agrees with the one you populated during back propagation. So this was
00:06:51.400 very standard for a long time. But today, of course, it is standard to use an autograd engine.
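A gradient checker of the kind mentioned here compares a hand-written backward pass against a finite-difference estimate of the gradient. The following is a minimal PyTorch sketch, not code from the lecture: the toy loss `sum(tanh(x @ W))` and the helper name `numerical_grad` are hypothetical, chosen only to illustrate the check.

```python
import torch

def numerical_grad(f, param, eps=1e-6):
    """Hypothetical helper: central-difference estimate of d f(param) / d param.

    f must return a scalar tensor; param should be float64 for accuracy.
    """
    grad = torch.zeros_like(param)
    flat_p = param.view(-1)   # a view shares storage, so editing it edits param
    flat_g = grad.view(-1)
    for i in range(flat_p.numel()):
        old = flat_p[i].item()
        flat_p[i] = old + eps
        f_plus = f(param).item()
        flat_p[i] = old - eps
        f_minus = f(param).item()
        flat_p[i] = old       # restore the original value
        flat_g[i] = (f_plus - f_minus) / (2 * eps)
    return grad

torch.manual_seed(0)
x = torch.randn(4, 3, dtype=torch.float64)
W = torch.randn(3, 2, dtype=torch.float64)

def f(W):
    return torch.tanh(x @ W).sum()

# Hand-written backward pass for L = sum(tanh(x @ W))
h = torch.tanh(x @ W)
dz = 1.0 - h**2   # dL/dh is all ones, and d tanh(z)/dz = 1 - tanh(z)^2
dW = x.T @ dz     # chain rule through the matrix multiply

num = numerical_grad(f, W)
print(torch.allclose(dW, num, atol=1e-6))
```

Central differences are used because their error shrinks with the square of `eps`, which makes agreement with the analytic gradient much tighter than a one-sided difference would.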
00:06:55.480 But it was definitely useful. And I think people sort of understood how these neural networks work
00:07:00.520 on a very intuitive level. And so I think it's a good exercise again. And this is where we want to
00:07:04.680 be. Okay, so just as a reminder from our previous lecture, this is the Jupyter Notebook that we
00:07:09.000 implemented at the time. And we're going to keep everything the same. So we're still going to have
00:07:14.280 a two-layer multilayer perceptron with a batch normalization layer. So the forward pass will be
00:07:18.840 basically identical to that lecture. But here we're going to get rid of loss.backward(). And instead,
00:07:23.240 we're going to write the backward pass manually. Now here's the starter code for this lecture.
00:07:27.720 We are becoming a backprop ninja in this notebook. And the first few cells here are identical to
00:07:34.360 what we are used to. So we are doing some imports, loading in the data set and processing the data
00:07:39.240 set. None of this changed. Now here, I'm introducing a utility function that we're going to use later
00:07:44.600 to compare the gradients. So in particular, we are going to have the gradients that we estimate
00:07:48.600 manually ourselves. And we're going to have gradients that PyTorch calculates. And we're going to be
00:07:53.800 checking for correctness, assuming of course that PyTorch is correct. Then here we have the
00:07:59.480 initialization that we are quite used to. So we have our embedding table for the characters,
00:08:04.680 the first layer, second layer, and a batch normalization in between. And here's where we create all the
00:08:09.560 parameters. Now you will note that I changed the initialization a little bit to be small numbers.
00:08:15.240 So normally you would set the biases to be all zero. Here I am setting them to be small random
00:08:19.800 numbers. And I'm doing this because if your variables are initialized to exactly zero,
00:08:25.160 sometimes what can happen is that can mask an incorrect implementation of a gradient.
00:08:29.080 Because when everything is zero, it sort of simplifies and gives you a much simpler expression
00:08:34.600 of the gradient than you would otherwise get. And so by making it small numbers, I'm trying to
00:08:39.160 unmask those potential errors in these calculations. You'll also notice that I'm using b1 in the first
00:08:46.600 layer. I'm using a bias despite batch normalization right afterwards. So this would typically not be
00:08:52.440 what you do because we talked about the fact that you don't need a bias. But I'm doing this here
00:08:56.360 just for fun, because we're going to have a gradient with respect to it. And we can check that we are
00:09:00.760 still calculating it correctly, even though this bias is spurious. So here I'm calculating a single
00:09:06.760 batch. And then here I'm doing a forward pass. Now you'll notice that the forward pass is significantly
00:09:12.280 expanded from what we are used to. Here the forward pass was just here. Now the reason that
00:09:18.680 the forward pass is longer is for two reasons. Number one, here we just had an F.cross_entropy.
00:09:23.960 But here I am bringing back an explicit implementation of the loss function. And number two,
00:09:29.080 I broke up the implementation into manageable chunks. So we have a lot more intermediate
00:09:35.880 tensors along the way in the forward pass. And that's because we are about to go backwards and
00:09:40.200 calculate the gradients in this back propagation from the bottom to the top. So we're going to go
00:09:47.240 upwards. And just like we have, for example, the logprobs tensor in the forward pass,
00:09:51.160 in the backward pass we're going to have a dlogprobs, which is going to store the
00:09:54.600 derivative of the loss with respect to the logprobs tensor. And so we're going to be
00:09:58.440 prepending d to every one of these tensors and calculating it along the way of this back
00:10:03.400 propagation. So as an example, we have a bnraw here; we're going to be calculating a dbnraw.
00:10:09.080 So here I'm telling PyTorch that we want to retain the grad of all these intermediate values,
00:10:15.960 because here in exercise one, we're going to calculate the backward pass. So we're going to
00:10:20.040 calculate all these d variables and use the cmp function I've introduced above to check our
00:10:25.880 correctness with respect to what PyTorch is telling us. This is going to be exercise one,
00:10:31.000 where we sort of back propagate through this entire graph. Now, just to give you a very quick
00:10:36.040 preview of what's going to happen in exercise two and below, here we have fully broken up the loss
00:10:42.280 and back propagated through it manually in all the little atomic pieces that make it up.
00:10:47.080 But here we're going to collapse the loss into a single cross entropy call. And instead,
00:10:52.120 we're going to analytically derive using math and paper and pencil, the gradient of the loss
00:10:58.840 with respect to the logits. And instead of back propagating through all of its little chunks one
00:11:02.840 at a time, we're just going to analytically derive what that gradient is. And we're going to implement
00:11:07.080 that, which is much more efficient, as we'll see in a bit. Then we're going to do the exact same
00:11:12.440 thing for batch normalization. So instead of breaking up batchnorm into all the tiny components,
00:11:17.560 we're going to use pen and paper and mathematics and calculus to derive the gradient through the
00:11:22.920 batchnorm layer. So we're going to calculate the backward pass through the batchnorm layer in a
00:11:28.120 much more efficient expression, instead of backpropagating through all of its little pieces
00:11:32.120 independently. So it's going to be exercise three. And then exercise four, we're going to put it
00:11:37.960 all together. And this is the full code of training this two layer MLP. And we're going to basically
00:11:43.640 insert our manual backprop; we're going to take out loss.backward(). And you will basically see
00:11:49.160 that you can get all the same results using fully your own code. And the only thing we're using from
00:11:55.800 PyTorch is the torch.tensor to make the calculations efficient. But otherwise, you will understand
00:12:01.960 fully what it takes to forward and backward your neural net and train it. And I think that will be awesome.
00:12:06.520 So let's get to it. Okay, so I ran all the cells of this notebook all the way up to here. And I'm
00:12:12.840 going to erase this and I'm going to start implementing the backward pass, starting with dlogprobs.
00:12:17.160 So we want to understand what should go here to calculate the gradient of the loss
00:12:21.720 with respect to all the elements of the logprobs tensor. Now I'm going to give away the answer
00:12:26.760 here, but I wanted to put a quick note here that I think would be most pedagogically useful for you,
00:12:31.800 which is to actually go into the description of this video and find the link to this Jupyter notebook.
00:12:37.000 You can find it both on GitHub, but you can also find a Google Colab with it. So you don't have to
00:12:40.680 install anything; you'll just go to a website on Google Colab. And you can try to implement
00:12:45.400 these derivatives or gradients yourself. And then if you are not able to, come to my video and see
00:12:51.160 me do it. And so work in tandem: try it first yourself and then see me give away the answer.
00:12:57.480 And I think that would be most valuable to you. And that's how I recommend you go through this
00:13:00.280 lecture. So we are starting here with dlogprobs. Now, dlogprobs will hold the derivative of the
00:13:07.320 loss with respect to all the elements of logprobs. What is inside logprobs? The shape of this
00:13:13.640 is 32 by 27. So it's not going to surprise you that dlogprobs should also be an array of size
00:13:20.680 32 by 27, because we want the derivative of the loss with respect to all of its elements. So the sizes of
00:13:26.200 those are always going to be equal. Now, how does logprobs influence the loss? Okay,
00:13:33.960 loss is negative logprobs indexed with range of n and Yb, and then the mean of that. Now,
00:13:41.880 just as a reminder, Yb is just basically an array of all the correct indices.
00:13:51.640 So what we're doing here is we're taking the logprobs array of size 32 by 27. Right.
00:13:57.560 And then we are going into every single row. And in each row, we are plucking out the index
00:14:05.320 eight and then 14 and 15 and so on. So we're going down the rows; that's the iterator range of n.
00:14:10.040 And then we are always plucking out the index of the column specified by this tensor Yb.
00:14:16.120 So in the zeroth row, we are taking the eighth column. In the first row, we're taking the 14th
00:14:21.320 column, etc. And so logprobs[range(n), Yb] plucks out all those log probabilities of the correct
00:14:30.440 next character in a sequence. So that's what that does. And the shape of this or the size of it is,
00:14:35.960 of course, 32, because our batch size is 32. So these elements get plucked out. And then their
00:14:43.880 mean and the negative of that becomes loss. So I always like to work with simpler examples to
00:14:50.440 understand the numerical form of the derivative. What's going on here is once we've plucked out these
00:14:56.520 examples, we're taking the mean and then the negative. So the loss basically, if I can write it
00:15:03.560 this way, is the negative of the mean of, say, a, b, and c. The mean of those three numbers would be
00:15:09.560 (a + b + c) divided by three, so loss = -(a + b + c) / 3. That's for three numbers a, b, c,
00:15:14.440 although we actually have 32 numbers here. And so what is basically dloss by, say, da,
00:15:21.320 right? Well, if we simplify this expression mathematically, this is negative one over three of a
00:15:26.840 plus negative one over three of b plus negative one over three of c. And so what
00:15:33.480 is dloss by da? It's just negative one over three. And so you can see that if we don't just have
00:15:38.520 a, b and c, but we have 32 numbers, then dloss by, you know, every one of those numbers is going
00:15:45.400 to be negative one over n more generally, because n is the size of the batch, 32 in this case. So dloss by
00:15:54.680 dlogprobs is negative one over n in all these places. Now what about the other elements
00:16:03.720 inside logprobs? Because logprobs is a large array: you see that logprobs.shape is
00:16:08.280 32 by 27. But only 32 of them participate in the loss calculation. So what's the derivative of all
00:16:15.560 the other elements, most of the elements, that do not get plucked out here? Well, their gradient intuitively is
00:16:21.480 zero. And that's because they did not participate
00:16:26.120 in the loss. So most of these numbers inside this tensor do not feed into the loss. And so if we
00:16:32.280 were to change these numbers, then the loss doesn't change, which is equivalent to saying
00:16:37.640 that the derivative of the loss with respect to them is zero. They don't impact it.
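Written out for a batch of n examples, with y_i denoting the correct next-character index for row i (notation introduced here for compactness), the same derivation reads:

$$
\mathcal{L} = -\frac{1}{n}\sum_{i=1}^{n} \mathrm{logprobs}_{i,\,y_i}
\qquad\Longrightarrow\qquad
\frac{\partial \mathcal{L}}{\partial\, \mathrm{logprobs}_{i,j}} =
\begin{cases}
-\dfrac{1}{n} & \text{if } j = y_i,\\[4pt]
0 & \text{otherwise.}
\end{cases}
$$

With n = 3 this reduces to the -1/3 found for the a, b, c example.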
00:16:41.560 So here's a way to implement this derivative. We start out with torch.zeros of shape 32
00:16:49.320 by 27. Or let's just say, instead of doing this, because we don't want to hard-code numbers,
00:16:54.040 let's do torch.zeros_like(logprobs). So basically, this is going to create an array of
00:16:59.560 zeros exactly in the shape of logprobs. And then we need to set the derivative to negative one over n
00:17:06.120 inside exactly these locations. So here's what we can do: dlogprobs, indexed in the identical way,
00:17:12.840 will just be set to -1.0/n, right, just like we derived here.
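A self-contained version of this step, checked against autograd, might look like the sketch below. Random logits and targets stand in for the notebook's network output and `Yb`, and `torch.log_softmax` is used as a shortcut to produce `logprobs`; both are assumptions, since the lecture computes `logprobs` through several explicit steps.

```python
import torch

torch.manual_seed(42)
n, vocab = 32, 27                                   # batch size and vocabulary size used in the lecture
logits = torch.randn(n, vocab, requires_grad=True)  # stand-in for the network's output
Yb = torch.randint(0, vocab, (n,))                  # stand-in for the correct next-character indices

logprobs = torch.log_softmax(logits, dim=1)
logprobs.retain_grad()                              # keep .grad on this intermediate (non-leaf) tensor
loss = -logprobs[range(n), Yb].mean()               # pluck one entry per row, average, negate
loss.backward()

# Manual gradient: -1/n at the plucked positions, zero everywhere else
dlogprobs = torch.zeros_like(logprobs)
dlogprobs[range(n), Yb] = -1.0 / n

print(torch.allclose(dlogprobs, logprobs.grad))
```

`retain_grad()` plays the same role here as in the notebook: without it, PyTorch does not keep `.grad` on intermediate tensors, so there would be nothing to compare against.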
00:17:21.240 So now let me erase all of this reasoning. And then this is the candidate derivative
00:17:28.200 for logprobs. Let's uncomment the first line and check that this is correct.
00:17:34.120 Okay, so cmp ran, and let's go back to cmp. And you see that what it's doing is it's calculating if
00:17:41.960 the value calculated by us, which is dt, is exactly equal to t.grad as calculated by
00:17:48.360 PyTorch. And then this is making sure that all the elements are exactly equal, and then converting
00:17:54.440 this to a single Boolean value, because we don't want a Boolean tensor, we just want a Boolean value.
00:18:00.040 And then here, we are making sure that, okay, if they're not exactly equal, maybe they are
00:18:04.440 approximately equal because of some floating point issues, but they're very, very close.
00:18:08.120 So here we are using torch.allclose, which has a little bit of wiggle room available,
00:18:13.320 because sometimes you can get very, very close. But if you use a slightly different calculation,
00:18:18.760 because of floating point arithmetic, you can get a slightly different result. So this is
00:18:24.200 checking if you get an approximately close result. And then here we are checking the maximum,
00:18:29.720 basically the value that has the highest difference, and what the
00:18:35.160 absolute value of the difference between those two is. And so we are printing whether we have an exact
00:18:39.160 equality, an approximate equality, and what the largest difference is. And so here,
00:18:45.640 we see that we actually have exact equality. And so therefore, of course, we also have an
00:18:50.440 approximate equality. And the maximum difference is exactly zero. So basically, our dlogprobs is
00:18:57.000 exactly equal to what PyTorch calculated to be logprobs.grad in its backpropagation.
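The comparison just described can be sketched as a small helper. This is a reconstruction from the description, not the notebook's exact `cmp`. The return value and the print format are assumptions added so the result can also be inspected programmatically.

```python
import torch

def cmp(name, dt, t):
    """Sketch: compare a manually computed gradient dt against t.grad from autograd."""
    exact = torch.all(dt == t.grad).item()       # bitwise-identical everywhere?
    approx = torch.allclose(dt, t.grad)          # equal within floating-point tolerance?
    maxdiff = (dt - t.grad).abs().max().item()   # largest absolute discrepancy
    print(f"{name:15s} | exact: {str(exact):5s} | approximate: {str(approx):5s} | maxdiff: {maxdiff}")
    return exact, approx, maxdiff

# Usage: the gradient of sum(x^2) is 2x
x = torch.randn(5, requires_grad=True)
(x ** 2).sum().backward()
cmp("x", 2 * x.detach(), x)
```

Reporting all three numbers is useful because a manual gradient that uses a different but equivalent formula can fail the exact check while still passing the approximate one.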
00:19:02.920 So far, we're doing pretty well. Okay, so let's now continue our backpropagation.
00:19:07.880 We have that logprobs depends on probs through a log. So all the elements of probs are being
00:19:14.120 passed element-wise through log. Now, if we want dprobs, then remember your micrograd training.
00:19:22.120 We have like a log node; it takes in probs and creates logprobs. And dprobs will be the
00:19:28.520 local derivative of that individual operation, log, times the derivative of the loss with respect to its
00:19:34.520 output, which in this case is dlogprobs. So what is the local derivative of this operation?
00:19:40.120 Well, we are taking log element-wise, and we can come here and we can see, well, Wolfram
00:19:44.200 Alpha is your friend, that d by dx of log of x is just simply one over x. So therefore,
00:19:49.960 in this case, x is probs. So we have d by dx is one over x, which is one over probs.
00:19:55.560 And then this is the local derivative. And then times, we want to chain it. So this is chain rule,
00:20:00.920 times dlogprobs. Then let me uncomment this and let me run the cell in place. And we see that
00:20:08.360 the derivative of probs as we calculated here is exactly correct. And so notice here how this works:
00:20:14.920 probs is going to be inverted and then element-wise multiplied here. So if
00:20:21.160 your probs is very, very close to one, that means your network is currently predicting the
00:20:25.640 character correctly, then this will become one over one and dlogprobs just gets passed through.
00:20:30.680 But if your probabilities are incorrectly assigned, so if the correct character here
00:20:36.280 is getting a very low probability, then 1.0 dividing by it will boost this and then multiply
00:20:44.040 by dlogprobs. So basically, what this line is doing intuitively is it's taking the
00:20:48.600 examples that have a very low probability currently assigned and it's boosting their gradient.
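The same step in isolation: the sketch below treats `probs` as a leaf tensor built from random logits (an assumption, so the check is self-contained), reuses the -1/n pattern for `dlogprobs`, and compares `(1.0 / probs) * dlogprobs` with autograd. The last line shows the boosting effect: at the plucked positions the gradient is -1 / (n * p), which is largest in magnitude where the correct character's probability p is smallest.

```python
import torch

torch.manual_seed(0)
n, vocab = 4, 5
probs = torch.softmax(torch.randn(n, vocab), dim=1).requires_grad_()  # leaf tensor for this check
Yb = torch.randint(0, vocab, (n,))

logprobs = probs.log()
loss = -logprobs[range(n), Yb].mean()
loss.backward()

# Upstream gradient from the previous step
dlogprobs = torch.zeros_like(logprobs)
dlogprobs[range(n), Yb] = -1.0 / n
# Local derivative of log(x) is 1/x; chain it with the upstream gradient
dprobs = (1.0 / probs.detach()) * dlogprobs

print(torch.allclose(dprobs, probs.grad))
# Gradient at the correct characters: -1 / (n * p), large in magnitude when p is small
print(dprobs[range(n), Yb])
```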
00:20:53.160 Next up is counts_sum_inv. So we want the derivative of this.
00:21:01.400 Now, let me just pause here and kind of introduce what's happening here in general,
00:21:06.440 because I know it's a little bit confusing. We have the logits that come out of the neural
00:21:09.880 net. Here, what I'm doing is I'm finding the maximum in each row, and I'm subtracting it for the
00:21:15.800 purpose of numerical stability. And we talked about how if you do not do this, you run into numerical
00:21:20.760 issues if some of the logits take on too large values, because we end up exponentiating them.
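To see why the row maximum is subtracted, here is a small float32 demonstration with deliberately huge logits (values chosen for illustration only). Subtracting a per-row constant leaves the softmax mathematically unchanged but keeps `exp` from overflowing.

```python
import torch

logits = torch.tensor([[1000.0, 1001.0, 1002.0]])  # deliberately huge float32 logits

# Naive: exp overflows to inf, and inf / inf gives nan in every entry
naive = logits.exp() / logits.exp().sum(1, keepdim=True)

# Stable: subtract the row max first, so the largest exponent is exp(0) = 1
norm_logits = logits - logits.max(1, keepdim=True).values
counts = norm_logits.exp()
stable = counts / counts.sum(1, keepdim=True)

print(naive)
print(stable)
```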
00:21:25.480 So this is done just for numerical safety. Then here's the exponentiation of all the
00:21:32.040 logits to create our counts. And then we want to take the sum of these counts and normalize so
00:21:39.000 that all of the probs sum to one. Now here, instead of using one over counts_sum, I use counts_sum
00:21:44.280 raised to the power of negative one. Mathematically, they are identical. I just found that there's
00:21:49.080 something wrong with the PyTorch implementation of the backward pass of division, and it gives
00:21:54.280 like a weird result, but that doesn't happen for ** -1. So I'm using this formula
00:21:59.560 instead. But basically, all that's happening here is we've got the logits, we want to exponentiate
00:22:04.840 all of them, and we want to normalize the counts to create our probabilities. It's just that it's
00:22:09.640 happening across multiple lines. So now, here, we want to first take the derivative: we want to
00:22:20.600 backpropagate into counts_sum_inv, and then into counts as well. So what should dcounts_sum_inv be?
00:22:27.000 Now, we actually have to be careful here, because we have to scrutinize the shapes.
00:22:33.240 So counts.shape and counts_sum_inv.shape are different. So in particular,
00:22:41.400 counts is 32 by 27, but counts_sum_inv is 32 by one. And so in this multiplication here,
00:22:48.200 we also have an implicit broadcasting that PyTorch will do, because it needs to take this column
00:22:53.560 tensor of 32 numbers and replicate it horizontally 27 times to align these two tensors so it can do
00:22:59.400 an element-wise multiply. So really what this looks like is the following, using a toy example again.
00:23:05.000 What we really have here is just probs = counts * counts_sum_inv. So it's c = a * b,
00:23:11.000 but a is three by three, and b is just three by one, a column tensor. And so PyTorch internally
00:23:17.640 replicated the elements of b, and it did that across all the columns. So for example, b1,
00:23:23.720 which is the first element of b, would be replicated here across all the columns in this multiplication.
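The implicit replication can be made explicit with the same 3 by 3 and 3 by 1 toy shapes. In this sketch `a` and `b` are arbitrary values standing in for `counts` and `counts_sum_inv`; `b.repeat(1, 3)` spells out the copy that broadcasting performs implicitly.

```python
import torch

a = torch.arange(9.0).view(3, 3)          # stands in for counts (3x3 toy)
b = torch.tensor([[1.0], [2.0], [3.0]])   # stands in for counts_sum_inv (3x1 column)

implicit = a * b                          # PyTorch broadcasts b across the 3 columns
explicit = a * b.repeat(1, 3)             # the same product with the replication written out
print(torch.equal(implicit, explicit))
```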
00:23:29.320 And now we're trying to backpropagate through this operation to counts_sum_inv. So when we are
00:23:34.920 calculating this derivative, it's important to realize that this looks like a single
00:23:40.840 operation, but actually is two operations applied sequentially. The first operation that PyTorch did
00:23:46.680 is it took this column tensor and replicated it across all the columns, basically 27
00:23:53.960 times. So the first operation is a replication. And then the second operation is the multiplication.
00:23:59.320 So let's first backprop through the multiplication. If these two arrays were of the same size,
00:24:05.720 and we just have a and b, both of them three by three, then how do we backpropagate
00:24:11.640 through multiplication? So if we just have scalars and not tensors, then if you have c equals a times b,
00:24:17.240 then what is the derivative of c with respect to b? Well, it's just a. And so that's the local
00:24:23.400 derivative. So here in our case, backpropagating through just the
00:24:29.400 multiplication itself, which is element-wise, uses the local derivative, which in this case
00:24:34.760 is simply counts, because counts is the a. So it's the local derivative, and then times, because of the
00:24:42.760 chain rule, dprobs. So this here is the derivative or the gradient, but with respect to the replicated b.
00:24:52.040 But we don't have a replicated b; we just have a single b column. So how do we now backpropagate
00:24:57.080 through the replication? And intuitively, this b1 is the same variable, and it's just reused
00:25:03.480 multiple times. And so you can look at it as being equivalent to a case we've encountered in
00:25:09.640 micrograd. And so here, I'm just pulling out a random graph we used in micrograd. We had an
00:25:14.760 example where a single node has its output feeding into two branches of basically the graph until
00:25:22.280 the loss function. And we talked about how the correct thing to do in the backward pass is
00:25:26.680 we need to sum all the gradients that arrive at any one node. So across these different branches,
00:25:32.280 the gradients would sum. So if a node is used multiple times, the gradients for all of its uses
00:25:39.000 sum during backpropagation. So here, b1 is used multiple times in all these columns. And
00:25:44.840 therefore, the right thing to do here is to sum horizontally across each row, so to sum in
00:25:51.880 dimension one. But we want to retain this dimension so that counts_sum_inv
00:25:58.680 and its gradient are going to be exactly the same shape. So we want to make sure that we set
00:26:03.160 keepdim=True, so we don't lose this dimension. And this will make dcounts_sum_inv exactly
00:26:08.680 shape 32 by one. So uncommenting this comparison as well and running this, we see that we get an
00:26:16.760 exact match. So this derivative is exactly correct. And let me erase this. Now let's also
00:26:25.800 backpropagate into counts, which is the other variable here used to create probs. So from probs to
00:26:31.800 counts_sum_inv, we just did that; let's go into counts as well. So for dcounts,
00:26:37.080 counts is the a. So dc by da is just b. So therefore it's counts_sum_inv, and then times, by the
00:26:48.680 chain rule, dprobs. Now, counts_sum_inv is 32 by one. dprobs is 32 by 27. So
00:26:59.720 those will broadcast fine and will give us dcounts. There's no additional summation required here.
00:27:05.640 There will be a broadcasting that happens in this multiply here, because counts_sum_inv
00:27:11.400 needs to be replicated again to correctly multiply dprobs. But that's going to give the correct result.
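Both results from this part can be checked numerically. In the sketch below `counts` and `counts_sum_inv` are independent random leaves with the lecture's shapes, and `dprobs` is a random stand-in for the upstream gradient. That independence is an assumption of the sketch, so `dcounts` here is complete, whereas in the notebook it is only the first of two contributions.

```python
import torch

torch.manual_seed(0)
counts = torch.rand(32, 27, dtype=torch.float64, requires_grad=True)
counts_sum_inv = torch.rand(32, 1, dtype=torch.float64, requires_grad=True)
dprobs = torch.randn(32, 27, dtype=torch.float64)   # stand-in upstream gradient dL/dprobs

probs = counts * counts_sum_inv                     # (32, 27) * (32, 1): b broadcast across columns
probs.backward(dprobs)

# b was replicated across the 27 columns, so its per-column gradients are summed
dcounts_sum_inv = (counts.detach() * dprobs).sum(1, keepdim=True)   # shape (32, 1)
# dc/da = b; broadcasting replicates b again, and no summation is needed
dcounts = counts_sum_inv.detach() * dprobs                          # shape (32, 27)

print(dcounts_sum_inv.shape, torch.allclose(dcounts_sum_inv, counts_sum_inv.grad))
print(dcounts.shape, torch.allclose(dcounts, counts.grad))
```

Without `keepdim=True` the sum would have shape `(32,)`, and later broadcasting against a `(32, 27)` tensor would silently align it with the columns instead of the rows.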
00:27:17.640 So as far as the single operation is concerned, so we back propagate from props to counts,
00:27:24.200 but we can't actually check the derivative of counts. I have it much later on. And the reason for that
00:27:31.000 is because count some in depends on counts. And so there's a second branch here that we have to finish
00:27:37.160 because council mean back propagates into count some and council will back propagate into counts.
00:27:42.280 And so counts is a node that is being used twice. It's used right here in probs, and it goes through
00:27:47.640 this other branch through counts_sum. So even though we've calculated its first contribution,
00:27:53.400 we still have to calculate its second contribution later. Okay, so we're continuing with this
00:27:58.280 branch. We have the derivative for counts_sum_inv. Now we want the derivative of counts_sum.
00:28:02.600 So dcounts_sum equals... what is the local derivative of this operation? This is basically an element-wise
00:28:08.920 one over counts_sum. Raising counts_sum to the power of negative one is the same as
00:28:14.920 one over counts_sum. If we go to Wolfram Alpha, we see that d/dx of x to the negative one
00:28:21.400 is basically negative x to the negative two. Right: negative one over x squared is the same as
00:28:26.840 negative x to the negative two. So for dcounts_sum here, the local derivative is going to be negative
00:28:34.360 counts_sum to the negative two. That's the local derivative, times, by the chain rule, dcounts_sum_inv.
00:28:44.040 So that's dcounts_sum. Let's uncomment this and check that I am correct. Okay, so we have perfect
00:28:53.320 equality. And there's nothing sketchy going on here with any shapes, because these are of the same
00:28:59.720 shape. Okay, next up we want to backpropagate through this line. We have that counts_sum is
00:29:04.760 counts.sum along the rows. So I wrote out some help here. We have to keep in mind that counts,
00:29:12.200 of course, is 32 by 27 and counts_sum is 32 by 1. So in this backpropagation, we need to take this
00:29:18.520 column of derivatives and transform it into an array of derivatives for the original array.
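For reference, the 3 by 3 toy case can be sketched in NumPy, with NumPy swapped in for PyTorch. The toy A, B and the upstream gradient dB are made-up values, and np.ones_like plays the role of torch.ones_like.

```python
import numpy as np

A = np.arange(9.0).reshape(3, 3)       # toy 3 x 3 input
B = A.sum(axis=1, keepdims=True)       # forward: row sums, shape (3, 1)
dB = np.array([[0.1], [-0.2], [0.3]])  # pretend upstream gradient dL/dB

# b_i = a_i1 + a_i2 + a_i3, so db_i/da_ij is 1 within row i and 0 elsewhere:
# the addition routes dB[i] unchanged to every element of row i
dA = np.ones_like(A) * dB              # broadcasting replicates the column 3 times
print(dA.tolist())  # [[0.1, 0.1, 0.1], [-0.2, -0.2, -0.2], [0.3, 0.3, 0.3]]
```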
00:29:25.400 So what is this operation doing? We're taking in some kind of input, like say a three by
00:29:31.240 three matrix A, and we are summing up the rows into a column tensor B: B1, B2, B3. That is basically
00:29:38.760 this. So now we have the derivatives of the loss with respect to B, all the elements of B.
00:29:44.280 And now I want the derivative of the loss with respect to all these little a's. So how the B's depend
00:29:51.800 on the a's is basically what we're after. What is the local derivative of this operation? Well,
00:29:56.440 we can see here that B1 only depends on these elements here. The derivative of B1 with respect
00:30:02.600 to all of these elements down here is zero. But for these elements here, like a11, a12,
00:30:07.960 etc., the local derivative is one, right? So dB1/da11, for example, is one. So it's one, one,
00:30:16.280 and one. So when we have the derivative of the loss with respect to B1, the local derivative of B1
00:30:23.080 with respect to these inputs is zero here, but it's one on these guys. So in the chain rule, we have the
00:30:30.360 local derivative times sort of the derivative of B1. And so because the local derivative is one on
00:30:37.560 these three elements, the local derivative multiplying the derivative of B1 will just be
00:30:42.440 the derivative of B1. And so you can look at it as a router. Basically, an addition is a router
00:30:49.480 of gradient. Whatever gradient comes from above, it just gets routed equally to all the elements
00:30:53.880 that participate in that addition. So in this case, the derivative of B1 will just flow equally to
00:31:00.120 the derivatives of a11, a12 and a13. So if we have the derivatives of all the
00:31:04.840 elements of B in this column tensor, which is dcounts_sum that we've calculated just now,
00:31:11.080 we basically see that what that amounts to is all of these are now flowing to all these elements of A,
00:31:18.280 and they're doing that horizontally. So basically what we want is we want to take the
00:31:23.560 dcounts_sum of size 32 by 1, and we just want to replicate it 27 times horizontally to create a
00:31:29.880 32 by 27 array. There are many ways to implement this operation. You could of course just replicate
00:31:35.400 the tensor. But I think maybe one clean one is that dcounts is simply torch.ones_like
00:31:42.520 of counts, so just a two-dimensional array of ones in the shape of counts, 32 by 27, times dcounts_sum.
00:31:51.880 This way we're letting the broadcasting here basically implement the replication. You
00:31:57.640 can look at it that way. But then we also have to be careful, because dcounts was already
00:32:04.360 calculated. We calculated it earlier here, and that was just the first branch. And we're now finishing
00:32:09.640 the second branch. So we need to make sure that these gradients add, so plus-equals. And then here,
00:32:15.320 let's uncomment the comparison, and let's make sure, crossing fingers, that we have the correct
00:32:23.640 result. So PyTorch agrees with us on this gradient as well. Okay, hopefully we're getting the hang of
00:32:29.480 this now. Here counts is an element-wise exp of norm_logits. So now we want dnorm_logits.
00:32:35.400 And because it's an element-wise operation, everything is very simple. What is the local
00:32:40.120 derivative of e to the x? It's famously just e to the x. So this is the local derivative.
00:32:45.800 That is the local derivative. Now, we already calculated it and it's inside counts, so we
00:32:52.600 may as well just reuse counts. That is the local derivative, times dcounts.
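Chaining these lines together, the backward pass from dprobs down to dnorm_logits might be sketched as follows. NumPy is used in place of PyTorch, the shapes are small stand-ins, and dprobs is a random pretend upstream gradient. The `+=` on dcounts is the accumulation of the two branches into counts. The final check compares against central finite differences of a made-up loss L = sum(probs * dprobs).

```python
import numpy as np

rng = np.random.default_rng(1)
norm_logits = rng.standard_normal((4, 5))  # stand-in for the 32 x 27 norm_logits
dprobs = rng.standard_normal((4, 5))       # pretend upstream gradient dL/dprobs

def forward(nl):
    counts = np.exp(nl)
    counts_sum = counts.sum(axis=1, keepdims=True)
    counts_sum_inv = counts_sum ** -1
    probs = counts * counts_sum_inv
    return counts, counts_sum, counts_sum_inv, probs

counts, counts_sum, counts_sum_inv, probs = forward(norm_logits)

# backward, one forward line at a time
dcounts_sum_inv = (counts * dprobs).sum(axis=1, keepdims=True)
dcounts = counts_sum_inv * dprobs                  # first branch into counts
dcounts_sum = -counts_sum ** -2 * dcounts_sum_inv  # d/dx x**-1 = -x**-2
dcounts += np.ones_like(counts) * dcounts_sum      # second branch, accumulated with +=
dnorm_logits = counts * dcounts                    # d/dx exp(x) = exp(x), so reuse counts

def num_grad(f, x, eps=1e-6):
    """Central finite-difference gradient of scalar f with respect to array x."""
    g = np.zeros_like(x)
    for i in np.ndindex(x.shape):
        xp, xm = x.copy(), x.copy()
        xp[i] += eps
        xm[i] -= eps
        g[i] = (f(xp) - f(xm)) / (2 * eps)
    return g

# made-up loss L = sum(probs * dprobs), so dL/dprobs == dprobs
num = num_grad(lambda nl: np.sum(forward(nl)[3] * dprobs), norm_logits)
print(np.allclose(dnorm_logits, num, atol=1e-6))  # True
```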
00:32:58.440 Funny as that looks, counts times dcounts is the derivative on norm_logits. And now let's
00:33:07.000 uncomment this and verify. And it looks good. So that's norm_logits. Okay, so we are here on
00:33:16.280 this line now, norm_logits. We have that, and we're trying to calculate dlogits and
00:33:21.560 dlogit_maxes, so backpropagating through this line. Now we have to be careful here, because
00:33:26.600 the shapes again are not the same, and so there's implicit broadcasting happening here. norm_logits
00:33:32.440 has the shape 32 by 27, and so does logits. But logit_maxes is only 32 by 1. So there's
00:33:39.480 broadcasting here in the minus. Now here I tried to sort of write out a toy example again.
00:33:46.840 We basically have that this is our C equals A minus B. And we see that, because of the shapes,
00:33:52.280 these are three by three, but this one is just a column. And so for every element of C,
00:33:57.320 we have to look at how it came to be. And every element of C is just the corresponding element of
00:34:02.440 A minus basically the associated element of B. So it's very clear now that the derivatives of every one of
00:34:12.120 these C's with respect to their inputs are one for the corresponding A. And it's a negative one
00:34:19.960 for the corresponding B. And so therefore, the derivatives on the C will flow equally to the
00:34:29.080 corresponding A's. And then also to the corresponding B's. But then in addition to that, the B's are
00:34:34.840 broadcast. So we'll have to do the additional sum just like we did before. And of course,
00:34:39.880 the derivatives for the B's will undergo a minus, because the local derivative here is negative one.
00:34:45.320 So dc32/db3 is negative one. So let's just implement that. Basically, dlogits will be
00:34:53.240 exactly a copy of the derivative on norm_logits. So dlogits equals dnorm_logits, and I'll do
00:35:03.080 a .clone() for safety, so we're just making a copy. And then we have that dlogit_maxes
00:35:09.560 will be the negative of dnorm_logits, because of the negative sign. And then we have to be careful,
00:35:17.000 because logit_maxes is a column. And so, just like we saw before, because we keep replicating the
00:35:24.760 same elements across all the columns, then in the backward pass, because we keep reusing this,
00:35:32.200 these are all just like separate branches of use of that one variable. And so therefore,
00:35:36.840 we have to do a sum along dimension 1 with keepdim=True, so that we don't destroy this dimension.
00:35:42.600 And then dlogit_maxes will be the same shape. Now we have to be careful, because this dlogits is
00:35:48.280 not the final dlogits. And that's because not only do we get gradient signal into logits through
00:35:54.520 here, but logit_maxes is also a function of logits, and that's a second branch into logits.
00:36:00.280 So this is not yet our final derivative for logits. We will come back later for the second branch.
00:36:05.560 For now, dlogit_maxes is the final derivative. So let me uncomment this cmp here, and let's just
00:36:11.000 run this, and for logit_maxes PyTorch agrees with us. So that was the derivative through this line.
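The broadcast subtraction norm_logits = logits - logit_maxes could be sketched like this. NumPy is used in place of PyTorch, so `.copy()` plays the role of `.clone()`. The shapes are small stand-ins, and dnorm_logits is a random pretend upstream gradient. The check uses a made-up loss L = sum(norm_logits * dnorm_logits) and treats logits and logit_maxes as independent inputs, so it leaves out the later max branch into logits.

```python
import numpy as np

rng = np.random.default_rng(2)
logits = rng.standard_normal((4, 5))             # stand-in for the 32 x 27 logits
logit_maxes = logits.max(axis=1, keepdims=True)  # column, shape (4, 1)
dnorm_logits = rng.standard_normal((4, 5))       # pretend upstream gradient

# local derivative +1 for logits: the gradient is simply copied
dlogits = dnorm_logits.copy()                    # first branch only; the max branch is added later
# local derivative -1 for the broadcast column, summed over the columns that reused it
dlogit_maxes = (-dnorm_logits).sum(axis=1, keepdims=True)

def loss(lg, lm):
    # made-up loss L = sum(norm_logits * dnorm_logits), inputs treated as independent
    return np.sum((lg - lm) * dnorm_logits)

def num_grad(f, x, eps=1e-6):
    """Central finite-difference gradient of scalar f with respect to array x."""
    g = np.zeros_like(x)
    for i in np.ndindex(x.shape):
        xp, xm = x.copy(), x.copy()
        xp[i] += eps
        xm[i] -= eps
        g[i] = (f(xp) - f(xm)) / (2 * eps)
    return g

print(dlogits.shape, dlogit_maxes.shape)  # (4, 5) (4, 1)
print(np.allclose(dlogits, num_grad(lambda x: loss(x, logit_maxes), logits)))       # True
print(np.allclose(dlogit_maxes, num_grad(lambda x: loss(logits, x), logit_maxes)))  # True
```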
00:36:19.880 Now, before we move on, I want to pause here briefly, and I want to look at these logit_maxes
00:36:25.160 and especially their gradients. We talked in the previous lecture about how the
00:36:30.200 only reason we're doing this is for the numerical stability of the softmax that we are implementing
00:36:34.600 here. And we talked about how if you take these logits for any one of these examples, so one row
00:36:40.200 of this logits tensor, if you add or subtract any value equally to all the elements, then the
00:36:47.000 value of probs will be unchanged. You're not changing the softmax. The only thing that this
00:36:51.640 is doing is making sure that exp doesn't overflow. And the reason we're using a max is
00:36:56.600 that then we are guaranteed that, in each row of logits, the highest number is zero. And so this
00:37:01.960 will be safe. And so basically that has repercussions. If it is the case that changing
00:37:11.000 logit_maxes does not change probs, and therefore does not change the loss, then the gradient
00:37:16.280 on logit_maxes should be zero, right? Because saying those two things is the same. So indeed,
00:37:22.360 we hope that these are very, very small numbers. Indeed, we hope this is zero. Now, because of
00:37:26.920 floating-point sort of wonkiness, this doesn't come out exactly zero; only in some of the rows does it.
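A small NumPy experiment can show why this gradient should vanish. It assumes the standard closed form for the gradient of the mean cross-entropy loss with respect to the softmax inputs, (probs - one_hot) / n, which is not derived at this point in the lecture. The targets Yb are random hypothetical labels. Each row of that gradient sums to (1 - 1) / n = 0, so dlogit_maxes is zero up to rounding error.

```python
import numpy as np

rng = np.random.default_rng(3)
n, vocab = 32, 27
logits = rng.standard_normal((n, vocab)) * 3
Yb = rng.integers(0, vocab, size=n)  # hypothetical random targets, for illustration only

def softmax_naive(x):
    # no max subtraction inside, so shifted and unshifted inputs can be compared directly
    e = np.exp(x)
    return e / e.sum(axis=1, keepdims=True)

logit_maxes = logits.max(axis=1, keepdims=True)
norm_logits = logits - logit_maxes
probs = softmax_naive(norm_logits)
# subtracting the row max leaves probs unchanged (up to rounding)
print(np.allclose(probs, softmax_naive(logits)))  # True

# assumed closed form for the mean cross-entropy: dL/dnorm_logits = (probs - one_hot(Yb)) / n
dnorm_logits = probs.copy()
dnorm_logits[np.arange(n), Yb] -= 1.0
dnorm_logits /= n

# each row of dnorm_logits sums to zero, so this is only floating-point noise
dlogit_maxes = (-dnorm_logits).sum(axis=1, keepdims=True)
print(np.abs(dlogit_maxes).max())
```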
00:37:32.680 But we get extremely small values, like 1e-9 or 1e-10. And so this is telling us that the
00:37:38.440 values of logit_maxes are not impacting the loss, as they shouldn't. It feels kind of weird to
00:37:43.960 backpropagate through this branch, honestly, because if you have an implementation of, like,
00:37:50.280 F.cross_entropy in PyTorch, and you block together all these elements and you're not doing
00:37:54.440 the backpropagation piece by piece, then you would probably assume that the derivative through here
00:37:59.800 is exactly zero. So you would be sort of skipping this branch because it's only done for numerical
00:38:08.200 stability. But it's interesting to see that even if you break up everything into the full atoms and
00:38:13.080 you still do the computation as you'd like with respect to numerical stability, the correct thing
00:38:17.480 happens, and you still get very, very small gradients here, basically reflecting the fact that the
00:38:23.320 values of these do not matter with respect to the final loss. Okay, so let's now continue
00:38:28.760 backpropagation through this line here. We've just calculated dlogit_maxes, and now we want to
00:38:33.320 backprop into logits through this second branch. Now here, of course, we took logits and we took
00:38:38.760 the max along all the rows. And then we looked at its values here. Now the way this works is that
00:38:44.520 in PyTorch: this thing here, max, returns both the values and the indices at
00:38:52.920 which those maximum values occurred. Now, in the forward pass, we only used the values,
00:38:58.280 because that's all we needed. But in the backward pass, it's extremely useful to know about where
00:39:03.080 those maximum values occurred. And we have the indices at which they occurred. And this will,
00:39:08.200 of course, help us do the backpropagation. Because what should the backward pass be here
00:39:13.000 in this case? We have the logits tensor, which is 32 by 27, and in each row we find the maximum
00:39:18.200 value. And then that value gets plucked out into logit_maxes. And so intuitively,
00:39:23.560 basically, the derivative flowing through here should be: the local derivative is
00:39:33.240 one for the appropriate entry that was plucked out, and then times the upstream derivative of
00:39:39.400 logit_maxes. So really what we're doing here, if you think through it, is we need to take
00:39:43.800 dlogit_maxes, and we need to scatter it to the correct positions in these logits where the
00:39:51.640 maximum values came from. And so I came up with one line of code that sort of does that. Let me just
00:39:59.080 erase a bunch of stuff here. So, you could do it very similarly to what we've done
00:40:03.640 here, where we create zeros and then populate the correct elements. So we use the indices here,
00:40:10.120 and we would set them to be one. But you can also use one_hot. So F.one_hot,
00:40:16.680 and then I'm taking logits.max over the first dimension, .indices, and I'm telling
00:40:22.760 PyTorch that the dimension of every one of these tensors should be 27. And so what this is going to
00:40:31.960 do is... okay, I apologize, this is crazy. Let me plt.imshow this. It's really just an array
00:40:40.920 of where the maxes came from in each row. That element is one, and all the other elements
00:40:46.280 are zero. So it's a one-hot vector in each row, and these indices are now populating a single one
00:40:52.520 in the proper place. And then what I'm doing here is multiplying by dlogit_maxes. And
00:40:58.200 keep in mind that this is a column of 32 by 1. And so when I'm doing this times dlogit_maxes,
00:41:06.840 dlogit_maxes will broadcast, and that column will get replicated, and then the element-wise
00:41:12.360 multiply will ensure that each of these just gets routed to whichever one of these bits is turned
00:41:18.360 on. And so that's another way to implement this kind of operation, and both of these can be
00:41:26.360 used. I just thought I would show an equivalent way to do it. And I'm using plus-equals because
00:41:30.840 we already calculated dlogits here, and this is now the second branch. So let's
00:41:35.640 look at dlogits and make sure that this is correct. And we see that we have exactly the correct answer.
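The one-hot routing can be mirrored in NumPy. In this sketch, `argmax` stands in for the `.indices` returned by `max`, and indexing an identity matrix with `np.eye(k)[idx]` stands in for `F.one_hot(..., num_classes=k)`. The shapes are small stand-ins and dlogit_maxes is a random pretend upstream gradient. The result is compared with the "zeros, then fill the indices" variant described above.

```python
import numpy as np

rng = np.random.default_rng(4)
logits = rng.standard_normal((4, 5))        # stand-in for the 32 x 27 logits
dlogit_maxes = rng.standard_normal((4, 1))  # pretend upstream gradient

idx = logits.argmax(axis=1)                 # which column each row's max came from
one_hot = np.eye(logits.shape[1])[idx]      # (4, 5), a single 1 per row
dlogits_from_max = one_hot * dlogit_maxes   # column broadcasts; only the max entry gets gradient

# equivalent version: start from zeros and fill in the max positions
alt = np.zeros_like(logits)
alt[np.arange(logits.shape[0]), idx] = dlogit_maxes[:, 0]
print(np.allclose(dlogits_from_max, alt))  # True
```

In the lecture this contribution is then added onto dlogits with `+=`, because it is the second branch into logits.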
00:41:43.000 Next up, we want to continue with logits here. That is an outcome of a matrix multiplication
00:41:49.800 and a bias offset in this linear layer. So I've printed out the shapes of all these intermediate
00:41:56.920 tensors. We see that logits is of course 32 by 27, as we've just seen. Then h here is 32 by 64.
00:42:05.160 So these are 64-dimensional hidden states. And then this W2 matrix projects those 64-dimensional
00:42:10.840 vectors into 27 dimensions. And then there's a 27-dimensional offset, which is a one-dimensional
00:42:17.000 vector. Now we should note that this plus here actually broadcasts, because h multiplied by
00:42:23.560 W2 will give us a 32 by 27, and then in this plus, b2 is a 27-dimensional vector.
00:42:31.400 Now in the rules of broadcasting, what's going to happen with this bias vector is that this 1-D
00:42:36.280 vector of 27 will get aligned with a padded dimension of 1 on the left, and it will
00:42:43.160 basically become a row vector. And then it will get replicated vertically 32 times to make it 32
00:42:49.000 by 27. And then there's an element-wise addition. Now, the question is how do we backpropagate from
00:42:56.600 logits to the hidden states, the weight matrix W2 and the bias b2. And you might think that
00:43:03.080 we need to go to some matrix calculus and look up the derivative for a matrix
00:43:09.560 multiplication. But actually you don't have to do any of that. And you can go back to first
00:43:13.160 principles and derive this yourself on a piece of paper. And specifically what I like to do and
00:43:18.360 what I find works well for me is, you find a specific small example that you then fully write
00:43:23.800 out. And then, in the process of analyzing how that individual small example works, you will understand
00:43:28.840 a broader pattern, and you'll be able to generalize and write out the full general formula for how
00:43:35.080 these derivatives flow in an expression like this. So let's try that out. So pardon the low-budget
00:43:40.360 production here, but what I've done here is I'm writing it out on a piece of paper. Really,
00:43:44.920 what we are interested in is: we have A multiplied by B, plus C, and that creates a D. And we have the
00:43:52.200 derivative of the loss with respect to D. And we'd like to know what the derivative of the loss is
00:43:55.640 with respect to A, B and C. Now these here are little two-dimensional examples of a matrix
00:44:01.160 multiplication, two by two times a two by two, plus a two, a vector of just two elements,
00:44:07.720 c1 and c2, gives me a two by two. Now notice here that I have a bias vector called C,
00:44:15.160 with elements c1 and c2. But as I described over here, that bias vector will become
00:44:21.160 a row vector in the broadcasting and will replicate vertically. So that's what's happening here as
00:44:25.400 well: c1, c2 is replicated vertically, and we see how we have two rows of c1, c2 as a result.
00:44:31.720 So now when I say write it out, I just mean like this, basically break up this matrix multiplication
00:44:39.000 into the actual thing that's going on under the hood. So as a result of matrix multiplication
00:44:44.520 and how it works, D one one is the result of a dot product between the first row of A and the first
00:44:49.960 column of B: a11 b11 plus a12 b21 plus c1. And so on and so forth for all
00:44:59.480 the other elements of D. And once you actually write it out, it becomes obvious this is just a
00:45:03.880 bunch of multiplies and adds. And we know from micrograd how to differentiate multiplies and adds.
00:45:10.520 And so this is not scary anymore. Matrix multiplication is just tedious,
00:45:15.640 unfortunately, but this is completely tractable. We have dL/dD for all of these, and we want
00:45:21.320 dL by all these other little variables. So how do we achieve that? How do we actually get the
00:45:26.840 gradients? Okay, so the low-budget production continues here. So let's, for example, derive the
00:45:32.680 derivative of the loss with respect to a11. We see here that a11 occurs twice in our
00:45:38.600 simple expression, right here and right here, and influences d11 and d12. So
00:45:43.880 what is dL/da11? Well, it's dL/dd11 times the local derivative of d11,
00:45:52.600 which in this case is just b11, because that's what's multiplying a11 here.
00:45:58.040 And likewise, here the local derivative of d12 with respect to a11 is just b12.
00:46:03.880 And so b12 will, in the chain rule, multiply dL/dd12. And then because a11
00:46:10.360 is used both to produce d11 and d12, we need to add up the contributions of both of
00:46:17.320 those sort of chains that are running in parallel. And that's why we get a plus, just adding up those
00:46:23.000 two contributions. And that gives us dL/da11. We can do the exact same analysis
00:46:29.880 for all the other elements of A. And when you simply write it out, it's just super
00:46:35.880 simple, taking gradients on, you know, expressions like this. You find that this matrix dL/dA
00:46:46.200 that we're after, if we just arrange all of them in the same shape as A... So A is just a
00:46:52.440 two by two matrix, so dL/dA here will also be a tensor of the same shape, with the derivatives now.
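Written out as equations, the 2 by 2 case on the paper looks roughly like this. The bias C = (c_1, c_2) is broadcast as a row, and only the derivatives with respect to A are spelled out.

```latex
\begin{aligned}
d_{11} &= a_{11}b_{11} + a_{12}b_{21} + c_1, &\quad d_{12} &= a_{11}b_{12} + a_{12}b_{22} + c_2,\\
d_{21} &= a_{21}b_{11} + a_{22}b_{21} + c_1, &\quad d_{22} &= a_{21}b_{12} + a_{22}b_{22} + c_2,\\[4pt]
\frac{\partial L}{\partial a_{11}} &= \frac{\partial L}{\partial d_{11}}b_{11} + \frac{\partial L}{\partial d_{12}}b_{12}, &\quad
\frac{\partial L}{\partial a_{12}} &= \frac{\partial L}{\partial d_{11}}b_{21} + \frac{\partial L}{\partial d_{12}}b_{22},\\
\frac{\partial L}{\partial a_{21}} &= \frac{\partial L}{\partial d_{21}}b_{11} + \frac{\partial L}{\partial d_{22}}b_{12}, &\quad
\frac{\partial L}{\partial a_{22}} &= \frac{\partial L}{\partial d_{21}}b_{21} + \frac{\partial L}{\partial d_{22}}b_{22},\\[4pt]
\frac{\partial L}{\partial A} &= \frac{\partial L}{\partial D}\,B^{\top}.
\end{aligned}
```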
00:47:02.680 So dL/da11, etc. And we see that actually we can express what we've written out here as a
00:47:09.160 matrix multiply. And so it just so happens that all of these formulas that we've derived
00:47:15.720 here by taking gradients can actually be expressed as a matrix multiplication. And in particular,
00:47:20.680 we see that it is the matrix multiplication of these two matrices. So it is dL/dD,
00:47:29.160 matrix multiplied with B, but B transpose, actually. So you see that b21 and b12
00:47:36.120 have changed places, whereas before we had, of course, b11, b12, b21, b22.
00:47:42.520 So you see that this other matrix B is transposed. And so basically what we have, long story short,
00:47:49.400 just by doing very simple reasoning here by breaking up the expression in the case of a very
00:47:53.960 simple example, is that dL/dA, which is this, is simply equal to dL/dD matrix multiplied
00:48:02.440 with B transpose. So that is what we have so far. Now we also want the derivative with respect to
00:48:09.640 B and C. Now, for B, I'm not actually doing the full derivation because, honestly,
00:48:17.080 it's not deep. It's just annoying. It's exhausting. You can actually do this analysis yourself.
00:48:23.400 You'll also find that if you take these expressions and you differentiate with respect
00:48:27.160 to B instead of A, you will find that dL/dB is also a matrix multiplication. In this case,
00:48:33.240 you have to take the matrix A and transpose it, and matrix multiply that with dL/dD.
00:48:38.280 And that's what gives you dL/dB. And then here, for the offsets c1 and c2,
00:48:46.520 if you again just differentiate with respect to c1, you will find an expression like this,
00:48:52.120 and for c2, an expression like this. And basically you'll find that dL/dC is simple, because
00:48:58.680 they're just offsetting these expressions. You just have to take the dL/dD matrix
00:49:03.080 of the derivatives of D, and you just have to sum down each column, over the rows. And that gives you the
00:49:11.320 derivatives for C. So long story short, the backward pass of a matrix multiply is a matrix multiply.
00:49:20.040 And just like we had d equals a times b plus c in the scalar case, we sort of
00:49:26.360 arrive at something very, very similar, but now with a matrix multiplication instead of a scalar
00:49:31.000 multiplication. So the derivative of the loss with respect to A is dL/dD matrix multiply B
00:49:40.440 transpose, and for B it's A transpose matrix multiply dL/dD. In both cases it's a matrix multiplication
00:49:47.160 with the derivative and the other term in the multiplication. And for C it is a sum.
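These three results for logits = h @ W2 + b2 can be checked numerically. The sketch below uses NumPy in place of PyTorch, with small stand-in shapes (4 by 3 hidden states, 5 outputs) and a random pretend upstream gradient dlogits. The finite-difference check assumes a made-up loss L = sum(logits * dlogits), chosen so that dL/dlogits equals dlogits.

```python
import numpy as np

rng = np.random.default_rng(5)
h = rng.standard_normal((4, 3))        # stand-in for the 32 x 64 hidden states
W2 = rng.standard_normal((3, 5))       # stand-in for the 64 x 27 weights
b2 = rng.standard_normal(5)            # 1-D bias, broadcast as a row
dlogits = rng.standard_normal((4, 5))  # pretend upstream gradient

# forward: logits = h @ W2 + b2
dh = dlogits @ W2.T                    # (4, 5) @ (5, 3) -> (4, 3), same shape as h
dW2 = h.T @ dlogits                    # (3, 4) @ (4, 5) -> (3, 5), same shape as W2
db2 = dlogits.sum(axis=0)              # sum over the examples -> (5,), same shape as b2

def loss(h_, W2_, b2_):
    # made-up loss L = sum(logits * dlogits), so dL/dlogits == dlogits
    return np.sum((h_ @ W2_ + b2_) * dlogits)

def num_grad(f, x, eps=1e-6):
    """Central finite-difference gradient of scalar f with respect to array x."""
    g = np.zeros_like(x)
    for i in np.ndindex(x.shape):
        xp, xm = x.copy(), x.copy()
        xp[i] += eps
        xm[i] -= eps
        g[i] = (f(xp) - f(xm)) / (2 * eps)
    return g

print(np.allclose(dh, num_grad(lambda x: loss(x, W2, b2), h)))    # True
print(np.allclose(dW2, num_grad(lambda x: loss(h, x, b2), W2)))   # True
print(np.allclose(db2, num_grad(lambda x: loss(h, W2, x), b2)))   # True
```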
00:49:54.760 Now I'll tell you a secret. I can never remember the formulas that we just derived for backpropagating
00:50:00.760 through matrix multiplication, and I can backpropagate through these expressions just fine.
00:50:04.840 And the reason this works is because the dimensions have to work out. So let me give you an example.
00:50:11.000 Say I want to create dh. What should dh be? Number one, I have to know that the shape of dh
00:50:18.680 must be the same as the shape of h, and the shape of h is 32 by 64. And then the other piece of
00:50:24.600 information I know is that dh must be some kind of matrix multiplication of dlogits with W2.
00:50:31.640 And dlogits is 32 by 27, and W2 is 64 by 27. There is only a single way to make the shapes
00:50:40.840 work out in this case, and it is indeed the correct result. In particular here, dh needs to be 32 by
00:50:48.120 64. The only way to achieve that is to take dlogits and matrix multiply it with... you see how
00:50:55.720 I have to take W2, but I have to transpose it to make the dimensions work out. So W2 transpose.
00:51:01.800 And it's the only way to matrix multiply those two pieces to make the shapes
00:51:05.960 work out. And that turns out to be the correct formula. So if we come here, we want dh, which is
00:51:11.960 dA. And we see that dA is dL/dD matrix multiply B transpose. So that's dlogits multiply, and B
00:51:20.360 is W2, so W2 transpose, which is exactly what we have here. So there's no need to remember these
00:51:26.280 formulas. Similarly, now if I want dW2, well, I know that it must be a matrix multiplication of
00:51:35.240 dlogits and h, and maybe there's a few transposes, or there's one transpose, in there as well. And I
00:51:41.080 don't know which way it is. So I have to come to W2, and I see that its shape is 64 by 27.
00:51:45.960 And that has to come from some matrix multiplication of these two. And so to get a 64 by 27, I need to
00:51:54.040 take h and transpose it, and then I need to matrix multiply it. So that will become 64 by 32.
00:52:03.720 And then I need to make sure it's multiplying with the 32 by 27. And that's going to give me a 64 by 27.
00:52:08.760 So I need to matrix multiply this with dlogits, and checking .shape, it comes out just like that. That's the
00:52:12.840 only way to make the dimensions work out. And just use matrix multiplication. And if we come here,
00:52:18.520 we see that that's exactly what's here. So A transpose, where A for us is h, multiplied with dlogits.
00:52:25.960 So that's dW2. And then db2 is just the vertical sum. And actually, in the same way, there's only
00:52:36.360 one way to make the shapes work out. I don't have to remember that it's a vertical sum along the
00:52:40.520 zeroth axis, because that's the only way that this makes sense: b2's shape is 27, so in
00:52:46.600 order to get that from dlogits here, which is 32 by 27, knowing that it's just a sum over dlogits
00:52:55.640 in some direction, that direction must be zero, because I need to eliminate this dimension.
00:53:02.920 So it's this. So this is kind of the hacky way. Let me copy-paste and delete that,
00:53:11.080 and let me swing over here. And this is our backward pass for the linear layer, hopefully.
00:53:16.360 So now let's uncomment these three. And we're checking that we got all the three
00:53:23.240 derivatives correct, and run. And we see that h, W2 and b2 are all exactly correct. So we've
00:53:32.200 backpropagated through a linear layer. Now next up, we have the derivative for h already, and we need
00:53:40.120 to backpropagate through tanh into hpreact. So we want to derive dhpreact. And here we have
00:53:47.720 to backpropagate through a tanh. We've already done this in micrograd, and we remember that tanh
00:53:52.280 has a very simple backward formula. Now, unfortunately, if I just put d/dx of tanh of x into
00:53:57.880 Wolfram Alpha, it lets us down. It tells us that it's the hyperbolic secant function of x, squared.
00:54:03.640 It's not exactly helpful. But luckily, Google image search does not let us down. And it gives
00:54:08.920 us the simpler formula. In particular, if you have that a is equal to tanh of z, then da/dz,
00:54:15.240 backpropagating through tanh, is just one minus a squared. And take note that in one minus a squared,
00:54:22.040 a here is the output of the tanh, not the input to the tanh, z. So da/dz is here formulated
00:54:29.480 in terms of the output of that tanh. And here, also in Google image search, we have the full
00:54:34.280 derivation, if you want to actually take the definition of tanh and work through the math to
00:54:39.480 figure out one minus tanh squared of z. So one minus a squared is the local derivative. In our case,
00:54:46.920 that is one minus the output of tanh squared, and the output here is h, so it's one minus h squared. And that is the local
00:54:56.120 derivative. And then times, by the chain rule, dh. So that is going to be our candidate implementation.
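A quick NumPy sketch of this candidate is shown below. NumPy stands in for PyTorch, and the shapes and the upstream gradient dh are small random stand-ins. It checks that 1 - h**2, written in terms of the tanh output, matches a finite-difference slope of np.tanh.

```python
import numpy as np

rng = np.random.default_rng(6)
hpreact = rng.standard_normal((4, 3))  # stand-in for the 32 x 64 pre-activations
dh = rng.standard_normal((4, 3))       # pretend upstream gradient dL/dh

h = np.tanh(hpreact)
dhpreact = (1.0 - h**2) * dh           # local derivative 1 - a**2, in terms of the output a = h

# tanh acts element-wise, so nudging every entry at once still gives per-entry slopes
eps = 1e-6
slope = (np.tanh(hpreact + eps) - np.tanh(hpreact - eps)) / (2 * eps)
print(np.allclose(dhpreact, slope * dh))  # True
```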
00:55:03.320 So if we come here, and then uncomment this, let's hope for the best. And we have the right answer.
00:55:12.360 Okay, next up, we have dhpreact, and we want to backpropagate into the gain, bnraw,
00:55:17.560 and bnbias. So here, these are the batchnorm parameters, bngain and bnbias, inside the
00:55:22.840 batchnorm, that take bnraw, which is exactly unit Gaussian, and they scale it and shift it.
00:55:28.440 And these are the parameters of the batchnorm. Now, here, we have a multiplication, but it's
00:55:34.760 worth noting that this multiply is very, very different from this matrix multiply here.
00:55:38.920 Matrix multiplies are dot products between rows and columns of the matrices involved.
00:55:43.560 This is an element-wise multiply, so things are quite a bit simpler. Now, we do have to be
00:55:48.760 careful with some of the broadcasting happening in this line of code, though. So you see how bngain
00:55:53.400 and bnbias are 1 by 64, but hpreact and bnraw are 32 by 64.
00:56:00.680 So we have to be careful with that and make sure that all the shapes work out fine and that the
00:56:05.880 broadcasting is correctly backpropagated. So in particular, let's start with dbngain.
00:56:10.440 So dbngain should be... And here, this is again an element-wise multiply. And whenever we have
00:56:18.280 A times B equals C, we saw that the local derivative here is just, if this is A, the local derivative
00:56:23.960 is just B, the other one. So the local derivative is just bnraw, and then times, by the chain rule,
00:56:31.480 dhpreact. So this is the candidate gradient. Now again, we have to be careful, because bngain
00:56:40.520 is of size 1 by 64, but this here would be 32 by 64. And so the correct thing to do in this
00:56:50.920 case, of course, follows from the fact that bngain here is a row vector of 64 numbers. It gets replicated
00:56:56.120 vertically in this operation. And so therefore, the correct thing to do is to sum, because it's
00:57:02.120 being replicated. And therefore, all the gradients in each of the rows that are now flowing backwards
00:57:08.360 need to sum up into that same tensor, bngain. So we sum across dimension zero, all the examples
00:57:16.680 basically, which is the direction in which it gets replicated. And now we also have to be careful,
00:57:21.560 because bngain is of shape 1 by 64. So in fact, I need keepdim=True. Otherwise,
00:57:29.800 I would just get 64. Now I don't actually really remember why bngain and bnbias,
00:57:36.280 I made them be 1 by 64, but the biases b1 and b2 I just made one-dimensional
00:57:44.680 vectors; they're not two-dimensional tensors. So I can't recall exactly why I left the gain and
00:57:51.000 the bias as two dimensional. But it doesn't really matter as long as you are consistent and you're
00:57:55.160 keeping it the same. So in this case, we want to keep the dimension so that the tensor shapes work.
00:57:59.480 Next up, we have bnraw. So dbnraw will be bngain multiplying dhpreact. That's our
00:58:13.640 chain rule. Now what about the dimensions of this? We have to be careful, right? So dhpreact is
00:58:22.440 32 by 64, and bngain is 1 by 64, so it will just get replicated to create this multiplication,
00:58:30.840 which is the correct thing, because in a forward pass, it also gets replicated in just the same way.
00:58:35.080 So in fact, we don't need the brackets here, we're done. And the shapes are already correct.
00:58:40.840 And finally, for the bias: this bias here is very, very similar to the bias we
00:58:46.760 saw in the linear layer. And we see that the gradients from hpreact will simply flow
00:58:52.360 into the biases and add up, because these are just offsets. And so basically,
00:58:57.720 we want this to be dhpreact, but it needs to sum along the right dimension. And in this case,
00:59:03.800 similar to the gain, we need to sum across the zeroth dimension, the examples, because of the way
00:59:09.000 that the bias gets replicated vertically. And we also want to have keepdim=True.
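The scale-and-shift line hpreact = bngain * bnraw + bnbias and its three gradients could be sketched in NumPy as follows. NumPy's `keepdims=True` corresponds to PyTorch's `keepdim=True`. The shapes are small stand-ins, with the gain and bias kept 2-D as in the lecture, and dhpreact is a random pretend upstream gradient. The finite-difference check assumes a made-up loss L = sum(hpreact * dhpreact).

```python
import numpy as np

rng = np.random.default_rng(7)
bnraw = rng.standard_normal((4, 3))     # stand-in for the 32 x 64 normalized activations
bngain = rng.standard_normal((1, 3))    # kept 2-D (1 x n), as in the lecture
bnbias = rng.standard_normal((1, 3))
dhpreact = rng.standard_normal((4, 3))  # pretend upstream gradient

# forward: hpreact = bngain * bnraw + bnbias (gain and bias replicated down the rows)
dbngain = (bnraw * dhpreact).sum(axis=0, keepdims=True)  # (1, 3): sum over the replicated rows
dbnraw = bngain * dhpreact                               # (4, 3): gain broadcasts, no sum needed
dbnbias = dhpreact.sum(axis=0, keepdims=True)            # (1, 3)

def loss(g, r, b):
    # made-up loss L = sum(hpreact * dhpreact), so dL/dhpreact == dhpreact
    return np.sum((g * r + b) * dhpreact)

def num_grad(f, x, eps=1e-6):
    """Central finite-difference gradient of scalar f with respect to array x."""
    grad = np.zeros_like(x)
    for i in np.ndindex(x.shape):
        xp, xm = x.copy(), x.copy()
        xp[i] += eps
        xm[i] -= eps
        grad[i] = (f(xp) - f(xm)) / (2 * eps)
    return grad

print(dbngain.shape, dbnraw.shape, dbnbias.shape)  # (1, 3) (4, 3) (1, 3)
print(np.allclose(dbngain, num_grad(lambda x: loss(x, bnraw, bnbias), bngain)))  # True
print(np.allclose(dbnraw, num_grad(lambda x: loss(bngain, x, bnbias), bnraw)))   # True
print(np.allclose(dbnbias, num_grad(lambda x: loss(bngain, bnraw, x), bnbias)))  # True
```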
00:59:14.440 And so this will basically take this and sum it up and give us a 1 by 64. So this is the
00:59:21.720 candidate implementation, and it makes all the shapes work. Let me bring it down here, and then let me
00:59:29.160 uncomment these three lines to check that we are getting the correct result for all the three
00:59:34.840 tensors. And indeed, we see that all of that got back propagated correctly. So now we get to the
00:59:40.280 batch norm layer. We see here that bngain and bnbias are the parameters, so
00:59:45.400 backpropagation ends there. But bnraw is the output of the standardization. So here, what I'm doing,
00:59:52.600 of course, is breaking up the batch norm into manageable pieces, so we can backpropagate
00:59:56.200 through each line individually. But basically, what's happening is that bnmeani is the sum divided by n.
01:00:03.880 So this is mu, the batch mean; I apologize for the variable naming. bndiff is x minus mu.
01:00:10.520 bndiff2 is x minus mu, squared, here inside the variance. bnvar is the variance, so
01:00:19.160 sigma squared; this is bnvar. And it's basically the sum of squares. So this is the
01:00:26.440 x minus mu squared, and then the sum. Now, you'll notice one departure here. In the paper, it is normalized
01:00:33.240 as 1 over m, which is the number of examples. Here, I'm normalizing as 1 over n minus 1
01:00:39.720 instead of n. And this is deliberate. I'll come back to that in a bit when we are at this line.
01:00:44.600 It is called Bessel's correction, but this is how I want it in our case.
01:00:49.560 bnvar_inv then becomes basically bnvar plus epsilon, where epsilon is 1e-5. And
01:00:57.800 then one over the square root is the same as raising to the power of negative 0.5,
01:01:03.720 right? Because 0.5 is a square root, and then the negative makes it one over the square root.
01:01:08.040 So bnvar_inv is one over this denominator here. And then we can see that bnraw, which is the
01:01:15.560 x-hat here, is equal to bndiff, the numerator, multiplied by bnvar_inv.
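The step-by-step decomposition above can be sketched as follows. The variable names are my reading of the spoken ones (bnmeani, bndiff, bndiff2, bnvar, bnvar_inv, bnraw), the input is random, and the final line only confirms that the pieces reproduce standardization with the n minus 1 variance:

```python
import torch

torch.manual_seed(0)
n = 32
hprebn = torch.randn(n, 64, dtype=torch.float64)  # stand-in pre-activations

# Batch norm standardization, broken into atomic steps
bnmeani = 1 / n * hprebn.sum(0, keepdim=True)       # mu, 1 x 64
bndiff = hprebn - bnmeani                           # x - mu, 32 x 64
bndiff2 = bndiff**2                                 # (x - mu)^2
bnvar = 1 / (n - 1) * bndiff2.sum(0, keepdim=True)  # Bessel-corrected variance, 1 x 64
bnvar_inv = (bnvar + 1e-5) ** -0.5                  # 1 / sqrt(var + eps)
bnraw = bndiff * bnvar_inv                          # x-hat, 32 x 64

# Same result in one expression; Tensor.var defaults to the unbiased (n - 1) estimate
ref = (hprebn - hprebn.mean(0, keepdim=True)) / torch.sqrt(hprebn.var(0, keepdim=True) + 1e-5)
print(torch.allclose(bnraw, ref))
```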
01:01:24.840 And this line here that creates hpreact was the last piece; we've already backpropagated
01:01:29.560 through it. So now we are here and we have bnraw, and we have to first
01:01:36.360 backpropagate into bndiff and bnvar_inv. So now we're here and we have dbnraw,
01:01:42.680 and we need to backpropagate through this line. Now, I've written out the shapes here,
01:01:48.760 and indeed bnvar_inv is of shape 1 by 64. So there is broadcasting happening here that we
01:01:55.480 have to be careful with, but it is just a simple elementwise multiplication. By now we should be
01:02:00.040 pretty comfortable with that. To get dbndiff, we know that this is just bnvar_inv
01:02:05.480 multiplied with dbnraw. And conversely, to get dbnvar_inv, we need to take bndiff
01:02:17.800 and multiply that by dbnraw. So this is the candidate, but of course we need to make sure
01:02:25.400 that broadcasting is obeyed. In particular, bnvar_inv multiplying with dbnraw
01:02:30.920 will be okay and give us 32 by 64, as we expect. But dbnvar_inv would be taking a 32 by 64,
01:02:42.360 multiplying it by a 32 by 64, so this is a 32 by 64. But of course, bnvar_inv is only
01:02:50.520 1 by 64. So the second line here needs a sum across the examples, and because there's this
01:02:57.880 dimension here, we need to make sure that keepdim=True. So this is the candidate.
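A sketch of just this line's backward pass, assuming random stand-ins. Making bndiff and bnvar_inv leaf tensors (my simplification) isolates the line, so autograd reports exactly the local gradients; in the full network, this dbndiff would still be missing its second branch:

```python
import torch

torch.manual_seed(0)
n, h = 32, 64
# Leaf stand-ins, so autograd sees only this one multiplication
bndiff = torch.randn(n, h, dtype=torch.float64, requires_grad=True)
bnvar_inv = torch.rand(1, h, dtype=torch.float64, requires_grad=True)

bnraw = bndiff * bnvar_inv                       # 1 x 64 row broadcast down 32 rows
dbnraw = torch.randn(n, h, dtype=torch.float64)  # pretend upstream gradient
bnraw.backward(dbnraw)

with torch.no_grad():
    dbndiff = bnvar_inv * dbnraw                         # 32 x 64, broadcasting is fine
    dbnvar_inv = (bndiff * dbnraw).sum(0, keepdim=True)  # sum back down to 1 x 64

print(dbndiff.shape, dbnvar_inv.shape)
```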
01:03:04.920 Let's erase this, swing down here, and implement it. And then let's uncomment the checks for
01:03:12.360 dbnvar_inv and dbndiff. Now, you'll actually notice that dbndiff, by the way, is going to be
01:03:20.680 incorrect. So when I run this, dbnvar_inv is correct but dbndiff is not. And this is
01:03:29.160 actually expected, because we're not done with bndiff. In particular, when we slide here,
01:03:35.480 we see that bnraw is a function of bndiff, but bnvar_inv is also a function
01:03:41.000 of bnvar, which is a function of bndiff2, which is a function of bndiff. So it comes in
01:03:45.720 here too. So bndiff (these variable names are crazy, I'm sorry) branches out into two branches,
01:03:53.240 and we've only done one of them. We have to continue our backpropagation and eventually
01:03:57.320 come back to bndiff, and then we'll be able to do a plus-equals and get the actual correct
01:04:01.960 gradient. For now, it is good to verify that cmp also works. It doesn't just lie to us and tell
01:04:07.160 us that everything is always correct; it can in fact detect when your gradient is not correct.
01:04:12.920 So that's good to see as well. Okay, so now we have the derivative here, and we're trying to
01:04:17.080 backpropagate through this line. And because we're raising to a power of negative 0.5,
01:04:21.560 I brought up the power rule. And we see that basically dbnvar will now be:
01:04:27.320 we bring down the exponent, so negative 0.5 times x, which is this, raised to the
01:04:35.560 power of negative 0.5 minus 1, which is negative 1.5. Now, we would also have to
01:04:41.880 apply a small chain rule here in our head, because we need to take a further derivative of
01:04:47.640 the expression inside the bracket with respect to bnvar. But because this is an
01:04:51.880 elementwise operation and everything is fairly simple, that's just 1, and so there's nothing
01:04:56.280 to do there. So this is the local derivative, and then times the global derivative to create the
01:05:01.800 chain rule, which is just times dbnvar_inv. So this is our candidate. Let me bring this down.
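The power-rule step as a small self-contained check, with a random positive stand-in for bnvar and a made-up upstream gradient; eps = 1e-5 as in the lecture:

```python
import torch

torch.manual_seed(0)
eps = 1e-5
bnvar = torch.rand(1, 64, dtype=torch.float64, requires_grad=True)  # positive stand-in

bnvar_inv = (bnvar + eps) ** -0.5
dbnvar_inv = torch.randn(1, 64, dtype=torch.float64)  # pretend upstream gradient
bnvar_inv.backward(dbnvar_inv)

with torch.no_grad():
    # d/dx x^(-0.5) = -0.5 * x^(-1.5); the inner derivative of (bnvar + eps) is 1
    dbnvar = (-0.5 * (bnvar + eps) ** -1.5) * dbnvar_inv

print(torch.allclose(dbnvar, bnvar.grad))
```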
01:05:10.440 And uncomment the check. And we see that we have the correct result. Now, before we
01:05:18.040 backpropagate through the next line, I want to briefly talk about the note here, where I'm using
01:05:21.800 Bessel's correction, dividing by n minus 1 instead of dividing by n when I normalize
01:05:27.400 the sum of squares here. Now, you'll notice that this is a departure from the paper, which uses 1
01:05:32.440 over n instead, not 1 over n minus 1. Their m is our n. And so it turns out that there are two
01:05:40.440 ways of estimating the variance of an array. One is the biased estimate, which is 1 over n,
01:05:47.480 and the other one is the unbiased estimate, which is 1 over n minus 1. Now, confusingly,
01:05:52.200 in the paper this is not very clearly described, and it's a detail that kind of matters, I think.
01:05:58.840 They are using the biased version at training time. But later, when they are talking about inference,
01:06:03.560 they mention that when they do inference, they are using the unbiased estimate,
01:06:09.320 which is the n minus 1 version, basically for inference and to calibrate the running mean
01:06:17.960 and running variance. And so they actually introduce a train-test mismatch,
01:06:23.320 where in training they use the biased version, and at test time they use the unbiased
01:06:27.400 version. I find this extremely confusing. You can read more about Bessel's correction
01:06:32.040 and why dividing by n minus 1 gives you a better estimate of the variance in cases where you
01:06:37.560 have very small samples from a population. And that is indeed the case
01:06:44.440 for us, because we are dealing with mini-batches, and these mini-batches are a small sample of a
01:06:49.720 larger population, which is the entire training set. And so it just turns out that if you
01:06:55.080 estimate it using 1 over n, that on average underestimates the variance. It is a
01:07:00.360 biased estimator, and it is advised that you use the unbiased version and divide by n minus 1.
01:07:05.160 And you can go through this article here that I liked, which describes the full reasoning,
01:07:09.480 and I'll link it in the video description. Now, when you calculate torch.var,
01:07:14.040 you'll notice that it takes an unbiased flag for whether you want to divide by n
01:07:18.760 or n minus 1. Confusingly, the docs do not mention what the default is for unbiased, but I believe
01:07:26.120 unbiased is True by default. I'm not sure why the docs here don't state that. Now, in
01:07:32.440 BatchNorm1d, the documentation again is kind of wrong and confusing. It says that the standard
01:07:38.280 deviation is calculated via the biased estimator, but this is actually not exactly right, and people
01:07:43.400 have pointed out that it is not right in a number of issues since then. Because actually,
01:07:48.360 the rabbit hole is deeper: they follow the paper exactly, and they use the biased version
01:07:53.880 for training, but when they're estimating the running standard deviation, they are using the unbiased
01:07:58.920 version. So again, there's the train-test mismatch. Long story short, I'm not a fan of
01:08:04.680 train-test discrepancies. I basically consider the fact that we use the biased version at
01:08:10.920 training time and the unbiased version at test time to be a bug. And I don't
01:08:15.560 think that there's a good reason for that. They don't really go into the detail of
01:08:19.640 the reasoning behind it in this paper. So that's why I basically prefer to use Bessel's correction
01:08:25.320 in my own work. Unfortunately, batch norm does not take a keyword argument that lets you choose whether
01:08:30.280 you want to use the unbiased version or the biased version in both training and test. And so
01:08:35.720 anyone using batch normalization basically, in my view, has a bit of a bug in the code.
01:08:42.040 And this turns out to be much less of a problem if your mini-batch sizes are a bit larger.
01:08:47.560 But still, I just find it kind of unpalatable. So maybe someone can explain why this is okay.
01:08:52.840 But for now, I prefer to use the unbiased version consistently, both during training and at test time.
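A small sketch of the two estimators, and of the train/test split inside nn.BatchNorm1d, on random stand-in data. Setting momentum=1.0 is my choice so that running_var simply equals this batch's estimate; the last two checks reflect my understanding of current PyTorch behavior (biased variance for normalization in training mode, unbiased variance stored in running_var):

```python
import torch

torch.manual_seed(0)
x = torch.randn(32, 64)  # one mini-batch of stand-in pre-activations
n = x.shape[0]

sq = (x - x.mean(0)) ** 2
biased = sq.sum(0) / n          # 1/n, as in the paper's training-time formula
unbiased = sq.sum(0) / (n - 1)  # 1/(n-1), Bessel's correction
print(torch.allclose(unbiased, x.var(0)))  # Tensor.var defaults to the unbiased estimate

# momentum=1.0: running_var = 0 * old + 1 * (this batch's estimate)
bn = torch.nn.BatchNorm1d(64, affine=False, momentum=1.0)
bn.train()
y = bn(x)
print(torch.allclose(y, (x - x.mean(0)) / torch.sqrt(biased + bn.eps), atol=1e-5))  # biased in forward
print(torch.allclose(bn.running_var, unbiased, atol=1e-5))                         # unbiased in buffer
```

With a batch of 32 the two estimates differ by a factor of 32/31, about 3%, which is why the mismatch matters less as mini-batches grow.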
01:08:58.360 And that's why I'm using 1 over n minus 1 here. Okay, so let's now actually backpropagate
01:09:03.320 through this line. The first thing that I always like to do is scrutinize the shapes
01:09:09.960 first. So in particular here, looking at the shapes of what's involved, I see that bnvar's
01:09:15.160 shape is 1 by 64, so it's a row vector, and bndiff2.shape is 32 by 64.
01:09:21.800 So clearly here we're doing a sum over the zeroth axis to squash the first dimension of
01:09:29.800 the shapes. So that right away hints to me that there will be some kind
01:09:36.280 of replication or broadcasting in the backward pass. And maybe you're noticing the pattern here,
01:09:41.080 but basically, anytime you have a sum in the forward pass, that turns into a replication
01:09:46.760 or broadcasting in the backward pass along the same dimension. And conversely, when we have a
01:09:51.880 replication or a broadcasting in the forward pass, that indicates a variable reuse, and so
01:09:58.760 in the backward pass that turns into a sum over the exact same dimension. And so hopefully you're
01:10:03.640 noticing the duality: those two are kind of the opposite of each other in the forward
01:10:07.320 and backward pass. Now, once we understand the shapes, the next thing I always like to do is
01:10:12.520 look at a toy example in my head to roughly understand how the
01:10:17.800 variable dependencies go in the mathematical formula. So here we have a two-dimensional array,
01:10:23.480 bndiff2, which we are scaling by a constant, and then we are summing vertically
01:10:30.200 over the columns. So if we have a 2 by 2 matrix a, and then we sum over the columns and scale,
01:10:35.560 we would get a row vector b1, b2. And b1 depends on a in this way, where it's just the sum of the
01:10:42.200 first column of a, scaled, and b2 in this way, where it's the second column, summed and scaled. And so looking at
01:10:50.280 this, what we want to do now is: we have the derivatives on b1 and b2, and we want to
01:10:55.480 backpropagate them into the a's. And so it's clear, just differentiating in your head, that the local
01:11:00.840 derivative here is 1 over n minus 1, times 1, for each one of these a's. And basically the derivative
01:11:10.360 of b1 has to flow through the columns of a, scaled by 1 over n minus 1. And that's roughly what's
01:11:17.320 happening here. So intuitively, the derivative flow tells us that dbndiff2 will be the local
01:11:26.760 derivative of this operation. And there are many ways to do this, by the way, but I like to do
01:11:30.760 something like this: torch.ones_like(bndiff2). So I'll create a large two-dimensional
01:11:37.480 array of ones, and then I will scale it by 1.0 divided by n minus 1. So this is an array of
01:11:45.880 1 over n minus 1, and that's sort of like the local derivative. And now for the chain rule,
01:11:52.600 I will simply multiply it by dbnvar. And notice here what's going to happen: this is
01:12:00.120 32 by 64 and this is just 1 by 64, so I'm letting the broadcasting do the replication,
01:12:07.080 because internally in PyTorch, dbnvar, which is a 1 by 64 row vector,
01:12:12.200 will, in this multiplication, get copied vertically until the two are of the same shape,
01:12:18.680 and then there will be an elementwise multiply. And so the broadcasting is basically doing
01:12:23.800 the replication, and I will end up with the derivatives of bndiff2 here. So this is the
01:12:30.760 candidate solution. Let's bring it down here, uncomment the line where we check it, and let's
01:12:37.640 hope for the best. And indeed, we see that this is the correct formula. Next up, let's differentiate
01:12:43.800 here into bndiff. So here we have that bndiff is elementwise squared to create bndiff2. So this is
01:12:51.640 a relatively simple derivative, because it's a simple elementwise operation, so it's kind of
01:12:55.720 like the scalar case. And we have that dbndiff should be: if this is x squared, then the derivative
01:13:02.280 of this is 2x, right? So it's simply 2 times bndiff. That's the local derivative. And then
01:13:09.240 times the chain rule, and they are of the same shape, so it's just times dbndiff2.
01:13:15.000 So that's the backward pass for this variable. Let me bring it down here.
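A sketch of the variance line and the squaring line together, assuming random stand-ins; retain_grad() keeps the intermediate gradient on bndiff2 so both manual results can be compared with autograd. The sum over dim 0 in the forward pass shows up as broadcasting of dbnvar in the backward pass:

```python
import torch

torch.manual_seed(0)
n, h = 32, 64
bndiff = torch.randn(n, h, dtype=torch.float64, requires_grad=True)  # leaf stand-in

bndiff2 = bndiff**2
bndiff2.retain_grad()                               # keep this intermediate gradient for checking
bnvar = 1 / (n - 1) * bndiff2.sum(0, keepdim=True)  # 1 x 64
dbnvar = torch.randn(1, h, dtype=torch.float64)     # pretend upstream gradient
bnvar.backward(dbnvar)

with torch.no_grad():
    # sum over dim 0 forward -> broadcast along dim 0 backward
    dbndiff2 = (1.0 / (n - 1)) * torch.ones_like(bndiff2) * dbnvar  # 32 x 64
    # elementwise square: local derivative 2x, times the upstream gradient
    dbndiff = 2 * bndiff * dbndiff2

print(dbndiff2.shape, torch.allclose(dbndiff, bndiff.grad))
```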
01:13:20.600 And now we have to be careful, because we already calculated dbndiff, right? So this is just the
01:13:25.160 end of the other branch coming back to bndiff. dbndiff was already backpropagated
01:13:32.680 way over here from bnraw, so we have now completed the second branch, and that's why I
01:13:39.320 have to do plus-equals. And if you recall, we had an incorrect derivative for bndiff before.
01:13:44.840 And I'm hoping that once we append this last missing piece, we get exact correctness.
01:13:49.720 So let's run, and bndiff now actually shows the exact correct derivative. So that's
01:13:58.520 comforting. Okay, so let's now backpropagate through this line here. The first thing we do,
01:14:04.040 of course, is check the shapes, and I wrote them out here. Basically, the shape of this
01:14:08.440 is 32 by 64, and hprebn is the same shape, but bnmeani is a row vector, 1 by 64. So this minus here
01:14:17.080 will actually do broadcasting, and so we have to be careful with that. And as a hint to us, again,
01:14:21.960 because of the duality, a broadcasting in the forward pass means a variable reuse, and therefore
01:14:27.240 there will be a sum in the backward pass. So let's write out the backward pass here now.
01:14:34.360 To backpropagate into hprebn: because these are the same shape, the local derivative for each
01:14:40.680 one of the elements here is just 1 for the corresponding element in here. So basically,
01:14:45.960 what this means is that the gradient simply copies; it's just a variable assignment,
01:14:51.160 it's equality. So I'm just going to clone this tensor, just for safety, to create an exact copy
01:14:56.760 of dbndiff. And then here, to backpropagate into this one, what I'm inclined to do is:
01:15:04.600 dbnmeani will basically be... what is the local derivative? Well, it's negative torch.ones_like
01:15:14.360 of bndiff, right? And then times the derivative here, dbndiff.
01:15:32.840 And this here is the backpropagation for the replicated bnmeani, so I still have to backpropagate
01:15:39.800 through the replication in the broadcasting, and I do that by doing a sum. So I'm going to take
01:15:45.000 this whole thing and do a sum over the zeroth dimension, which was the replication.
01:15:50.200 So if you scrutinize this, by the way, you'll notice that this is the same shape as that.
01:15:58.440 And so what I'm doing here doesn't actually make that much sense,
01:16:01.880 because it's just an array of ones multiplying dbndiff. So in fact, I can just do this,
01:16:08.360 and that is equivalent. So this is the candidate backward pass. Let me copy it here.
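A sketch of this subtraction's backward pass in isolation. Here bnmeani is a leaf stand-in, so the sketch covers only the branch coming through bndiff; the names and random inputs are placeholders:

```python
import torch

torch.manual_seed(0)
n, h = 32, 64
hprebn = torch.randn(n, h, dtype=torch.float64, requires_grad=True)
bnmeani = torch.randn(1, h, dtype=torch.float64, requires_grad=True)  # leaf stand-in

bndiff = hprebn - bnmeani                         # bnmeani is reused (broadcast) for every row
dbndiff = torch.randn(n, h, dtype=torch.float64)  # pretend upstream gradient
bndiff.backward(dbndiff)

with torch.no_grad():
    dhprebn = dbndiff.clone()                   # local derivative is 1: just copy
    dbnmeani = (-dbndiff).sum(0, keepdim=True)  # local derivative -1, summed over the reuse

print(torch.allclose(dhprebn, hprebn.grad), torch.allclose(dbnmeani, bnmeani.grad))
```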
01:16:16.120 And then let me uncomment this one and this one. Enter. And it's wrong. Damn.
01:16:29.320 Actually, sorry, this is supposed to be wrong. And it's supposed to be wrong because we are
01:16:34.840 backpropagating from bndiff into hprebn, but we're not done, because bnmeani depends on
01:16:41.880 hprebn, and there will be a second portion of that derivative coming from this second branch.
01:16:46.840 So we're not done yet, and we expect it to be incorrect. So there you go. So let's now
01:16:51.480 backpropagate from bnmeani into hprebn. And so here again, we have to be careful, because there's
01:16:59.080 a sum along the zeroth dimension, so this will turn into
01:17:04.440 broadcasting in the backward pass now. And I'm going to go a little bit faster on this line,
01:17:09.000 because it is very similar to the line that we had before, and to multiple lines in the past, in fact.
01:17:14.040 So for dhprebn, the gradient will be scaled by 1 over n. And then basically this
01:17:24.520 gradient here, dbnmeani, is going to be scaled by 1 over n, and then it's going to flow to
01:17:31.480 every row and deposit itself into dhprebn. So what we want is this thing scaled by
01:17:38.280 1 over n. We'll put the constant up front here. So we scale the gradient, and now we need to
01:17:47.720 replicate it across all the rows here. I like to do that with torch.ones_like
01:17:56.200 of hprebn, and I will let the broadcasting do the work of replication. So
01:18:14.920 like that. So this is dhprebn, and hopefully we can plus-equals that.
01:18:22.520 So this here is broadcasting, and then this is the scaling. So this should be correct.
01:18:32.680 Okay. So that completes the backpropagation of the batch norm layer, and we are now here.
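Putting the whole batch norm backward pass together, as a sketch on random inputs with a made-up upstream gradient dbnraw. The two += lines are where the second branch into bndiff and the second branch into hprebn get accumulated; the final check compares against autograd:

```python
import torch

torch.manual_seed(0)
n, h, eps = 32, 64, 1e-5
hprebn = torch.randn(n, h, dtype=torch.float64, requires_grad=True)

# forward, step by step (Bessel-corrected variance, as in the lecture)
bnmeani = 1 / n * hprebn.sum(0, keepdim=True)
bndiff = hprebn - bnmeani
bndiff2 = bndiff**2
bnvar = 1 / (n - 1) * bndiff2.sum(0, keepdim=True)
bnvar_inv = (bnvar + eps) ** -0.5
bnraw = bndiff * bnvar_inv

dbnraw = torch.randn(n, h, dtype=torch.float64)  # pretend upstream gradient
bnraw.backward(dbnraw)

with torch.no_grad():
    dbndiff = bnvar_inv * dbnraw                             # branch 1 into bndiff
    dbnvar_inv = (bndiff * dbnraw).sum(0, keepdim=True)
    dbnvar = -0.5 * (bnvar + eps) ** -1.5 * dbnvar_inv
    dbndiff2 = 1.0 / (n - 1) * torch.ones_like(bndiff2) * dbnvar
    dbndiff += 2 * bndiff * dbndiff2                         # branch 2 into bndiff: accumulate
    dhprebn = dbndiff.clone()                                # branch 1 into hprebn
    dbnmeani = (-dbndiff).sum(0, keepdim=True)
    dhprebn += 1.0 / n * torch.ones_like(hprebn) * dbnmeani  # branch 2 into hprebn: accumulate

print((dhprebn - hprebn.grad).abs().max())
```

Dropping either += line reproduces the "incorrect until the second branch arrives" behavior described above.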
01:18:39.080 Let's backpropagate through linear layer one here. Now, because everything is getting a
01:18:43.880 little vertically crazy, I copy-pasted the line here, and let's just backpropagate through this
01:18:48.520 one line. So first, of course, we inspect the shapes, and we see that this is 32 by 64,
01:18:54.840 embcat is 32 by 30, W1 is 30 by 64, and b1 is just 64. So as I mentioned,
01:19:05.720 backpropagating through linear layers is fairly easy, just by matching the shapes. So let's do that.
01:19:10.600 We have that dembcat should be some matrix multiplication of dhprebn with W1, with one
01:19:20.600 transpose thrown in there. So to make dembcat be 32 by 30, I need to take dhprebn,
01:19:32.040 32 by 64, and multiply it by W1 transpose. To get dW1, I need to end up with 30 by 64.
01:19:44.680 So to get that, I need to take embcat transpose and multiply that by
01:19:54.520 dhprebn. And finally, to get db1: this is an addition, and we saw that basically
01:20:05.000 I need to just sum the elements in dhprebn along some dimension. And to make the dimensions
01:20:11.240 work out, I need to sum along the zeroth axis here to eliminate this dimension, and we do not keep
01:20:17.640 dims, because we want to get just a single one-dimensional vector of 64. So these are the claimed
01:20:24.920 derivatives. Let me put that here, uncomment three lines, and cross our fingers.
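A sketch of the linear-layer backward pass, obtained purely by matching shapes as described, with random stand-ins for embcat, W1, b1 and the upstream gradient:

```python
import torch

torch.manual_seed(0)
# Lecture shapes: 32 examples, 30-dim concatenated embeddings, 64 hidden units
embcat = torch.randn(32, 30, dtype=torch.float64, requires_grad=True)
W1 = torch.randn(30, 64, dtype=torch.float64, requires_grad=True)
b1 = torch.randn(64, dtype=torch.float64, requires_grad=True)

hprebn = embcat @ W1 + b1
dhprebn = torch.randn(32, 64, dtype=torch.float64)  # pretend upstream gradient
hprebn.backward(dhprebn)

with torch.no_grad():
    dembcat = dhprebn @ W1.T  # (32x64) @ (64x30) -> 32x30
    dW1 = embcat.T @ dhprebn  # (30x32) @ (32x64) -> 30x64
    db1 = dhprebn.sum(0)      # no keepdim: b1 is a 1-D vector of 64

print(dembcat.shape, dW1.shape, db1.shape)
```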
01:20:32.600 Everything is great. Okay, so we continue; we're almost there. We have the derivative of embcat,
01:20:39.000 and we want to backpropagate into emb. So I again copied this line over here.
01:20:46.520 So this is the forward pass, and these are the shapes. So remember that the shape here was 32
01:20:51.800 by 30, and the original shape of emb was 32 by 3 by 10. So this layer in the forward pass,
01:20:57.880 as you recall, did the concatenation of these three 10-dimensional character vectors. And so
01:21:04.440 now we just want to undo that. So this is actually a relatively straightforward operation,
01:21:09.560 because for the backward pass of view: view is just a re-representation of the
01:21:14.920 array. It's just a logical form of how you interpret the array. So let's just reinterpret it to be what
01:21:20.600 it was before. In other words, demb is not 32 by 30. It is basically dembcat, but if you
01:21:30.920 view it as the original shape, so just emb.shape (you can pass tuples into view). And so this
01:21:40.760 should just be, okay, we just re-represent that view. And then we uncomment this line here, and hopefully,
01:21:50.200 yeah, the derivative of emb is correct. So in this case, we just have to re-represent the shape of
01:21:57.000 those derivatives into the original view. So now we are at the final line, and the only thing
01:22:01.720 that's left to backpropagate through is this indexing operation here, emb = C[Xb]. So as I did
01:22:08.520 before, I copy-pasted this line here. Let's look at the shapes of everything that's involved
01:22:12.600 and remind ourselves how this worked. So emb.shape was 32 by 3 by 10. So it's 32 examples,
01:22:21.240 and then we have three characters. Each one of them has a 10-dimensional embedding. And this was
01:22:27.560 achieved by taking the lookup table C, which has 27 possible characters, each of them 10-dimensional,
01:22:34.360 and we looked up the rows that were specified inside this tensor Xb. So Xb is 32 by 3, and
01:22:43.160 it's basically giving us, for each example, the identity or the index of which character is part
01:22:49.000 of that example. And so here I'm showing the first five rows of this tensor Xb. And so we
01:22:57.480 can see that, for example, for the first example in this batch, the characters with indices 1, 1,
01:23:02.360 and 4 come into the neural net, and then we want to
01:23:07.400 predict the next character in the sequence after the characters 1, 1, 4. So basically what's happening
01:23:13.160 here is there are integers inside Xb, and each one of these integers specifies which row of C
01:23:20.600 we want to pluck out, right? And then we arrange those rows that we've plucked out into a
01:23:27.960 32 by 3 by 10 tensor; we just package them into this tensor.
01:23:32.840 And now what's happening is that we have demb. So for every one of these plucked-out rows,
01:23:40.280 we have their gradients now, but they're arranged inside this 32 by 3 by 10 tensor.
01:23:45.880 So all we have to do now is route this gradient backwards through this assignment.
01:23:51.640 So we need to find which row of C every one of these 10-dimensional embeddings comes from,
01:23:58.120 and then we need to deposit them into dC. So we just need to undo the indexing. And of course,
01:24:06.200 if any of these rows of C was used multiple times, which almost certainly is the case,
01:24:10.840 like row 1 was used multiple times, then we have to remember that the gradients
01:24:15.160 that arrive there have to add. So for each occurrence, we have to have an addition.
01:24:21.160 So let's now write this out. And I don't actually know of a much better way to do this than a
01:24:25.400 for loop, unfortunately, in Python. So maybe someone can come up with a vectorized, efficient operation,
01:24:31.400 but for now let's just use for loops. So let me create torch.zeros_like(C) to initialize
01:24:38.920 just a 27 by 10 tensor of all zeros. And then, honestly, for k in range(Xb.shape[0]),
01:24:49.480 and maybe someone has a better way to do this, but for j in range(Xb.shape[1]).
01:24:54.280 This is going to iterate over all the elements of Xb, all these integers.
01:25:02.040 And then let's get the index at this position. So the index, ix, is basically Xb[k, j].
01:25:10.120 So an example of that is 1, or 4, and so on. And now, in the forward pass, we
01:25:19.480 basically took the row of C at index ix, and we deposited it into emb[k, j]. That's what
01:25:31.080 happened; that's where they are packaged. So now we need to go backwards, and we just need to route
01:25:35.320 demb at the position [k, j]. We now have these derivatives for each position, and each one is
01:25:44.600 10-dimensional. And we just need to go into the correct row of C. So dC at ix is this,
01:25:53.720 but plus-equals, because there could be multiple occurrences; the same row could have been
01:25:58.840 used many, many times. And so all of those derivatives will just go backwards through the indexing
01:26:05.640 and they will add. So this is my candidate solution. Let's copy it here.
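A sketch of the last two steps (undoing the view, then routing gradients back through the C[Xb] lookup) with a random table and random indices. The double loop mirrors the one described above; Tensor.index_add_ is one vectorized alternative I'm adding, not something from the lecture, and like the loop it accumulates repeated indices:

```python
import torch

torch.manual_seed(0)
# Lecture shapes: 27-character vocabulary, 10-dim embeddings, context of 3, batch of 32
C = torch.randn(27, 10, dtype=torch.float64, requires_grad=True)
Xb = torch.randint(0, 27, (32, 3))

emb = C[Xb]                                         # 32 x 3 x 10
embcat = emb.view(emb.shape[0], -1)                 # 32 x 30
dembcat = torch.randn(32, 30, dtype=torch.float64)  # pretend upstream gradient
embcat.backward(dembcat)

with torch.no_grad():
    demb = dembcat.view(emb.shape)  # undo the view: just reinterpret the shape

    dC = torch.zeros_like(C)
    for k in range(Xb.shape[0]):
        for j in range(Xb.shape[1]):
            ix = Xb[k, j]
            dC[ix] += demb[k, j]  # += because a row of C can be used many times

    # vectorized alternative (my addition): scatter-add rows by index
    dC_fast = torch.zeros_like(C).index_add_(0, Xb.view(-1), demb.view(-1, 10))

print(torch.allclose(dC, C.grad), torch.allclose(dC_fast, C.grad))
```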
01:26:14.040 Let's uncomment this and cross our fingers. Hey, so that's it. We've backpropagated through
01:26:24.440 this entire beast. So there we go. Totally makes sense. So now we come to exercise two.
01:26:33.080 It basically turns out that in this first exercise, we were doing way too much work.
01:26:36.600 We were backpropagating way too much. And it was all good practice and so on,
01:26:40.840 but it's not what you would do in practice. And the reason for that is, for example,
01:26:44.840 here I separated out this loss calculation over multiple lines, and I broke it up
01:26:49.720 all the way down to its smallest atomic pieces, and we backpropagated through all of those individually.
01:26:54.440 But it turns out that if you just look at the mathematical expression for the loss,
01:26:59.880 then actually you can do the differentiation on pen and paper and a lot of terms cancel and
01:27:05.080 simplify. And the mathematical expression you end up with can be significantly shorter and
01:27:09.560 easier to implement than backpropagating through all the pieces of everything you've done.
01:27:13.080 So before we had this complicated forward pass going from logits to the loss,
01:27:18.440 but in PyTorch, everything can just be glued together into a single call, F.cross_entropy.
01:27:23.880 You just pass in logits and the labels and you get the exact same loss as I verify here.
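A sketch of that equivalence on random logits and labels. This manual version skips subtracting the max logit for numerical stability, which is fine for small random logits but not something to rely on in general:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n = 32
logits = torch.randn(n, 27, dtype=torch.float64)  # stand-in logits
Yb = torch.randint(0, 27, (n,))                   # stand-in labels

# The long way: softmax, pluck out the correct-class probability, negative log, mean
counts = logits.exp()
probs = counts / counts.sum(1, keepdim=True)
loss_manual = -probs[range(n), Yb].log().mean()

loss_fast = F.cross_entropy(logits, Yb)  # one fused call
print(loss_manual.item(), loss_fast.item())
```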
01:27:28.440 So our previous loss and the fast loss, which computes that chunk of operations as a single mathematical
01:27:33.720 expression, are the same, but the fast one is much, much faster in the forward pass. It's also much, much faster
01:27:39.800 in the backward pass. And the reason for that is if you just look at the mathematical form of this
01:27:43.720 and differentiate again, you will end up with a very small and short expression. So that's what
01:27:48.280 we want to do here. We want to go directly, in a single operation,
01:27:54.840 to dlogits. And we need to implement dlogits as a function of logits and Yb. But it will be
01:28:03.000 significantly shorter than whatever we did here, where to get to dlogits, we had to go all the way
01:28:08.760 here. So all of this work can be skipped in favor of a much, much simpler mathematical expression
01:28:14.440 that you can implement here. So you can give it a shot yourself, basically look at what exactly
01:28:21.400 is the mathematical expression of loss and differentiate with respect to the logits.
01:28:24.920 So let me show you a hint. You can of course try it fully yourself. But if not, I can give you some
01:28:32.520 hint of how to get started mathematically. So basically what's happening here is we have
01:28:38.680 logits, then there's the softmax that takes the logits and gives you probabilities. Then we are
01:28:43.880 using the identity of the correct next character to pluck out a row of probabilities, take the
01:28:49.880 negative log of it to get our negative log probability. And then we average up all the
01:28:54.920 log probabilities or negative log probabilities to get our loss. So basically what we have is for
01:29:01.160 a single individual example, we have that the loss is equal to the negative log probability,
01:29:06.520 where p here is thought of as a vector of all the probabilities, indexed at the y
01:29:13.160 position, where y is the label. And we have that p here, of course, is the softmax. So the
01:29:21.160 i-th component of p, this probability vector, is just the softmax function: exponentiating all the logits
01:29:28.520 and normalizing so that everything sums to one. Now if you write out
01:29:36.440 p of y here, you can just write out the softmax. And then basically what we're interested in is
01:29:41.160 we're interested in the derivative of the loss with respect to the i-th logit.
01:29:46.360 And so basically it's d/dl_i of this expression here, where we have l indexed with the specific
01:29:54.520 label y on top, and on the bottom we have a sum over j of e to the l_j, and the negative log of all that.
01:29:59.720 So give it a shot with pen and paper and see if you can actually derive the expression
01:30:04.760 for dloss/dl_i. And then we're going to implement it here. Okay, so I am going to give away
01:30:10.120 the result here. So this is some of the math I did to derive the gradients analytically. And so we
01:30:17.320 see here that I'm just applying the rules of calculus from your first or second year of bachelor's
01:30:21.240 degree, if you took it. And we see that the expression actually simplifies quite a bit. You have to
01:30:26.280 separate out the analysis in the case where the i-th index that you're interested in inside
01:30:30.840 logits is either equal to the label or it's not equal to the label, and then the expression
01:30:35.560 simplifies and cancels in a slightly different way in each case. And what we end up with is something very,
01:30:39.960 very simple. We either end up with basically p at i, where p is again this vector of probabilities
01:30:47.160 after a softmax, or p at i minus one, where we simply subtract one. But in any case we
01:30:52.920 just need to calculate the softmax p, and then in the correct dimension we need to subtract one.
01:30:58.920 And that's the gradient, the form that it takes analytically. So let's implement this basically.
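Written out for a single example, with l the logits, y the label and p the softmax probabilities, the derivation goes roughly like this (my compact rewrite of the result described above, plus the 1/n scaling that applies when the loss is a mean over a batch of n examples):

```latex
\text{loss} = -\log p_y = -\log \frac{e^{l_y}}{\sum_j e^{l_j}} = -l_y + \log \sum_j e^{l_j}

\frac{\partial\,\text{loss}}{\partial l_i}
  = -\mathbb{1}[i = y] + \frac{e^{l_i}}{\sum_j e^{l_j}}
  = \begin{cases} p_i - 1, & i = y \\ p_i, & i \neq y \end{cases}

\text{batch mean over } n \text{ examples:}\quad
\frac{\partial\,\text{loss}}{\partial l_{k,i}} = \frac{1}{n}\left(p_{k,i} - \mathbb{1}[i = y_k]\right)
```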
01:31:03.960 And we have to keep in mind that this is only done for a single example. But here we are working
01:31:07.960 with batches of examples. So we have to be careful of that. And then the loss for a batch is the
01:31:14.280 average loss over all the examples. So in other words, the loss for the batch is
01:31:18.520 the loss for each individual example summed up and then divided by n. And we have to
01:31:24.280 back propagate through that as well and be careful with it. So dlogits is going to be F.softmax.
01:31:32.920 PyTorch has a softmax function that you can call. And we want to apply the softmax on the logits
01:31:38.120 and we want to go in the dimension that is one. So basically we want to do the softmax along the
01:31:44.440 rows of these logits. Then at the correct positions we need to subtract a one. So in dlogits,
01:31:51.720 iterating over all the rows and indexing into the columns provided by the correct labels inside Yb,
01:32:00.360 we need to subtract one. And then finally, the loss is the average loss. And in the
01:32:06.680 average there's a one over n of all the losses added up. And so we need to also back propagate
01:32:12.440 through that division. So the gradient has to be scaled down by n as well, because of the mean.
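A minimal sketch of this step, assuming random logits and random labels Yb as hypothetical stand-ins for the real batch (n = 32 examples, 27 characters, as in the lecture). It uses float64 so the comparison with autograd is tight; the variable names mirror the transcript, but the setup is illustrative:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n, vocab_size = 32, 27                     # batch size and number of characters
# Hypothetical stand-ins for the real batch: random logits and labels
logits = torch.randn(n, vocab_size, dtype=torch.float64, requires_grad=True)
Yb = torch.randint(0, vocab_size, (n,))

# Reference gradient from PyTorch autograd (cross_entropy averages over the batch)
loss = F.cross_entropy(logits, Yb)
loss.backward()

# Manual gradient: softmax along each row, subtract 1 at the correct label,
# then scale by 1/n for the mean over the batch
dlogits = F.softmax(logits.detach(), dim=1)
dlogits[torch.arange(n), Yb] -= 1
dlogits /= n

maxdiff = (dlogits - logits.grad).abs().max().item()
print(f"max diff: {maxdiff:.2e}")
```

Dividing by n is the backward pass of the mean; without it the manual gradient would be n times too large.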
01:32:18.840 But this otherwise should be the result. So now if we verify this, we see that we don't get an
01:32:26.040 exact match. But at the same time, the maximum difference between dlogits from PyTorch and our
01:32:33.720 dlogits here is on the order of 5e-9. So it's a tiny, tiny number. So because of floating
01:32:40.520 point wonkiness, we don't get the exact bitwise result. But we basically get the correct answer
01:32:46.040 approximately. Now I'd like to pause here briefly before we move on to the next exercise,
01:32:52.680 because I'd like us to get an intuitive sense of what dlogits is, because it has a beautiful
01:32:57.080 and very simple explanation, honestly. So here I'm taking dlogits and I'm visualizing it.
01:33:04.200 And we can see that we have a batch of 32 examples of 27 characters. And what is dlogits
01:33:09.800 intuitively, right? dlogits is the probabilities matrix from the forward pass.
01:33:15.800 But then here, these black squares are the positions of the correct indices where we subtracted a one.
01:33:21.720 And so what is this doing, right? These are the derivatives on the logits. And so let's look at
01:33:28.680 just the first row here. So that's what I'm doing here. I'm calculating the probabilities of these
01:33:34.360 logits, and then I'm taking just the first row. And this is the probability row. And then the
01:33:39.320 dlogits of the first row and multiplying by n just for us so that we don't have the scaling
01:33:45.320 by n in here, and everything is more interpretable. We see that it's exactly equal to the probability,
01:33:50.440 of course, but then the position of the correct index has a minus equals one. So minus one on that
01:33:56.120 position. And so notice that if you take dlogits at zero and you sum it, it actually sums
01:34:04.280 to zero. And so you should think of these gradients here at each cell as like a force. We are going
01:34:13.880 to be basically pulling down on the probabilities of the incorrect characters. And we're going to be
01:34:19.400 pulling up on the probability at the correct index. And that's what's basically happening
01:34:25.240 in each row. And the amount of push and pull is exactly equalized because the sum is zero. So the
01:34:34.360 amount to which we pulled down on the probabilities and the amount that we push up on the probability
01:34:39.000 of the correct character is equal. So it's sort of the repulsion and attraction are equal. And think
01:34:45.240 of the neural net now as a massive pulley system or something like that, where we're up here
01:34:50.840 on top of the logits, and we're pulling down the probabilities of the incorrect and
01:34:55.080 pulling up the probability of the correct. And in this complicated pulley system, because everything is
01:34:59.640 mathematically just determined, just think of it as sort of like this tension translating to
01:35:05.000 this complicated pulling mechanism. And then eventually we get a tug on the weights and the
01:35:09.400 biases. And basically in each update, we just kind of like tug in the direction that we like
01:35:14.280 for each of these elements. And the parameters are slowly giving in to the tug. And that's what
01:35:18.840 training a neural net kind of like looks like on a high level. And so I think the forces of push
01:35:24.440 and pull in these gradients are actually very intuitive here. We're pushing and pulling on the
01:35:29.960 correct answer and the incorrect answers. And the amount of force that we're applying is actually
01:35:34.600 proportional to the probabilities that came out in the forward pass. And so for example, if our
01:35:40.600 probabilities came out exactly correct, so they would have had zero everywhere except for one at
01:35:45.640 the correct position, then dlogits would be a row of all zeros for that example. There would
01:35:52.040 be no push and pull. So the amount to which your prediction is incorrect is exactly the amount
01:35:58.120 by which you're going to get a pull or push in that dimension. So if you have for example, a
01:36:03.480 very confidently mispredicted element here, then what's going to happen is that element is going
01:36:08.520 to be pulled down very heavily. And the correct answer is going to be pulled up to the same amount.
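The push-and-pull picture can be checked on a single hypothetical row of five logits (the label index and logit values below are made up for illustration). A positive entry means gradient descent will pull that logit down; the negative entry at the label means it will be pushed up:

```python
import torch
import torch.nn.functional as F

def dlogits_row(logits_row, label):
    """Cross-entropy gradient for one example (without the 1/n batch scaling)."""
    g = F.softmax(logits_row, dim=0)
    g[label] -= 1.0
    return g

label = 2  # hypothetical correct class

# Confident and correct: the probabilities are already (almost) one-hot
near_perfect = torch.tensor([-10.0, -10.0, 10.0, -10.0, -10.0])
# Confident and wrong: class 0 takes almost all of the probability mass
mispredicted = torch.tensor([10.0, 0.0, 0.0, 0.0, 0.0])

for name, row in [("near perfect", near_perfect), ("mispredicted", mispredicted)]:
    g = dlogits_row(row, label)
    print(f"{name:13s}", [round(v, 4) for v in g.tolist()], "row sum:", round(g.sum().item(), 6))
```

In the first row every entry is close to zero, so there is essentially no push or pull. In the second, the wrongly favored class 0 gets a gradient close to +1, the correct class gets one close to -1, and the remaining classes barely move. Both rows sum to zero up to floating-point error.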
01:36:14.200 And the other characters are not going to be influenced too much. So the amount to which you
01:36:20.280 mis-predict is then proportional to the strength of the pull. And that's happening independently in
01:36:26.040 all the dimensions of this tensor. And it's sort of very intuitive and very easy to think
01:36:30.680 through. And that's basically the magic of the cross entropy loss and what it's doing dynamically
01:36:35.720 in the backward pass of the neural net. So now we get to exercise number three,
01:36:39.480 which is a very fun exercise, depending on your definition of fun. And we are going to do for
01:36:45.160 batch normalization exactly what we did for cross entropy loss in exercise number two. That is,
01:36:50.040 we are going to consider it as a glued single mathematical expression and back propagate through
01:36:54.440 it in a very efficient manner, because we are going to derive a much simpler formula for the
01:36:58.760 backward pass of batch normalization. And we're going to do that using pen and paper. So previously,
01:37:04.280 we've broken up batch normalization into all of the little intermediate pieces and all the atomic
01:37:08.200 operations inside it. And then we back propagated through it one by one. Now we just have a single
01:37:15.000 sort of forward pass of a batchnorm. And it's all glued together. And we see that we get the
01:37:21.240 same result as before. Now for the batchnorm backward pass, we'd like to also implement
01:37:26.280 a single formula basically for back propagating through this entire operation. That is the
01:37:30.520 batch normalization. So in the forward pass previously, we took hprebn, the hidden states before
01:37:37.240 batch normalization, and created hpreact, which is the hidden states just before the activation.
01:37:43.160 In the batch normalization paper, hprebn is X and hpreact is Y. So in the backward pass,
01:37:50.440 what we'd like to do now is, given dhpreact, produce dhprebn.
01:37:56.520 And we'd like to do that in a very efficient manner. So that's the name of the game,
01:38:00.280 calculate dhprebn given dhpreact. And for the purposes of this exercise, we're going to
01:38:06.200 ignore gamma and beta and their derivatives, because they take on a very simple form in a very similar
01:38:11.640 way to what we did up above. So let's calculate this, given that right here. So to help you a
01:38:19.560 little bit like I did before, I started off the implementation here on pen and paper. And I took
01:38:26.280 two sheets of paper to derive the mathematical formulas for the backward pass. And basically,
01:38:31.320 to set up the problem, I just write out mu, the sigma square variance, X i hat and Y i exactly as in the
01:38:39.080 paper, except for the Bessel correction. And then in the backward pass, we have the derivative of
01:38:45.000 the loss with respect to all the elements of Y. And remember that Y is a vector, there are
01:38:50.120 multiple numbers here. So we have all the derivatives with respect to all the Ys,
01:38:56.440 and then there's a gamma and a beta. And this is kind of like the compute graph. The gamma and
01:39:01.480 the beta, there's the X hat, and then the mu and the sigma square and the X. So we have dl by dy,
01:39:09.400 and we want dl by dxi for all the i's in these vectors. So this is the compute graph. And you have
01:39:17.400 to be careful because I'm trying to note here that these are vectors. There's many nodes here inside
01:39:23.960 X, X hat and Y, but mu and sigma, sorry, sigma square are just individual scalars, single numbers.
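Written out for one neuron over a batch of m examples, the forward pass being differentiated is the following. This is my transcription of the setup described (the paper's equations, but with the 1/(m-1) Bessel-corrected variance). Only the x_i, x̂_i and y_i are per-example; mu and sigma squared are scalars shared by the whole batch:

```latex
\mu = \frac{1}{m} \sum_{i=1}^{m} x_i
\qquad
\sigma^2 = \frac{1}{m-1} \sum_{i=1}^{m} \left( x_i - \mu \right)^2

\hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}}
\qquad
y_i = \gamma \hat{x}_i + \beta
```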
01:39:31.800 So you have to be careful with that. You have to imagine there's multiple nodes here, or you're
01:39:35.480 going to get your math wrong. So as an example, I would suggest that you go in the following order,
01:39:41.800 one, two, three, four, in terms of the back propagation. So back propagating to X hat,
01:39:46.840 then to sigma square, then into mu, and then into X. Just like in a topological sort in
01:39:53.800 micrograd, we would go from right to left. You're doing the exact same thing, except you're doing
01:39:57.800 it with symbols and on a piece of paper. So for number one, I'm not giving away too much. If you
01:40:06.200 want dl by d X hat, then I would just take dl by dy and multiply by gamma because of this
01:40:14.440 expression here, where any individual yi is just gamma times X i hat plus beta. So it doesn't help
01:40:21.800 you too much there, but this gives you basically the derivatives for all the X hats. And so now,
01:40:27.560 try to go through this computational graph and derive what is dl by d sigma square,
01:40:34.040 and then what is dl by d mu, and then what is dl by dx eventually. So give it a go, and I'm going
01:40:41.960 to be revealing the answer one piece at a time. Okay, so to get dl by d sigma square, we have to
01:40:47.480 remember again, like I mentioned, that there are many X hats here. And remember that sigma
01:40:53.400 square is just a single individual number here. So when we look at the expression
01:40:57.880 for dl by d sigma square, we have to actually consider all the possible paths: there are
01:41:05.720 many X hats, and they all depend on sigma
01:41:13.640 square. So sigma square has a large fan out, there's lots of arrows coming out from sigma square into
01:41:18.920 all the X hats. And then there's a back propagating signal from each X hat into sigma square. And
01:41:25.400 that's why we actually need to sum over all those i's from i equal to one to m of the dl by d
01:41:32.680 X i hat, which is the global gradient, times d X i hat by d sigma square, which is the local gradient
01:41:42.120 of this operation here. And then mathematically, I'm just working it out here, and I'm simplifying
01:41:47.880 and you get a certain expression for dl by d sigma square. We're going to be using this
01:41:52.520 expression when we back propagate into mu and then eventually into X. So now let's continue our
01:41:57.240 back propagation into mu. So what is dl by d mu? Now again, be careful that mu influences X hat
01:42:04.280 and X hat is actually lots of values. So for example, if our mini batch size is 32, as it is in our
01:42:09.800 example that we were working on, then this is 32 numbers and 32 arrows going back to mu.
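The fan-in into mu can be checked numerically. In this hypothetical sketch, mu is made its own leaf tensor so autograd has to sum over every path leaving it (the 32 arrows through the X hats plus the one arrow through sigma squared); the upstream gradient dL/dx̂ is random and eps = 1e-5 is an assumed value:

```python
import torch

torch.manual_seed(0)
m, eps = 32, 1e-5
x = torch.randn(m, dtype=torch.float64)
dxhat = torch.randn(m, dtype=torch.float64)      # hypothetical upstream dL/dx_hat

# Make mu its own leaf node so autograd must sum every path that leaves it
mu = x.mean().detach().requires_grad_(True)
sigma2 = ((x - mu) ** 2).sum() / (m - 1)         # one arrow: mu -> sigma^2
xhat = (x - mu) / torch.sqrt(sigma2 + eps)       # m arrows: mu -> each x_hat_i
xhat.backward(dxhat)                              # mu.grad accumulates all m + 1 contributions

# The same quantities by hand
inv_std = (sigma2.detach() + eps) ** -0.5
direct = -(dxhat * inv_std).sum()                 # sum over the m arrows through x_hat
dsigma2_dmu = -2.0 * (x - mu.detach()).sum() / (m - 1)   # local gradient on the sigma^2 arrow

print("autograd dL/dmu:", mu.grad.item())
print("direct paths   :", direct.item())
print("dsigma2/dmu    :", dsigma2_dmu.item())    # ~0 because mu is the batch mean
```

Autograd's mu.grad matches the sum over the 32 direct arrows alone, because the local derivative along the sigma squared arrow is zero when mu is the batch mean.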
01:42:15.080 And then mu going to sigma square is just a single arrow because sigma square is a scalar. So in
01:42:20.120 total, there are 33 arrows emanating from mu. And then all of them have gradients coming into mu,
01:42:26.920 and they all need to be summed up. And so that's why when we look at the expression for dl by d mu,
01:42:33.400 I am summing up over all the gradients of dl by d X hat times d X hat by d mu. So that's
01:42:41.320 this arrow, the 32 arrows here. And then plus the one arrow from here, which is dl by d
01:42:47.320 sigma square times d sigma square by d mu. So now we have to work out that expression. And let
01:42:53.000 me just reveal the rest of it. Simplifying the first term here is not complicated, and you just
01:42:59.400 get an expression here. For the second term though, there's something really interesting that
01:43:03.160 happens. When we look at d sigma square by d mu, and we simplify, at one point, if we assume that
01:43:11.160 in a special case where mu is actually the average of X i's, as it is in this case, then if we plug
01:43:18.840 that in, then actually the gradient vanishes and becomes exactly zero. And that makes the entire
01:43:24.520 second term cancel. And so if you just have a mathematical expression like this, and you look
01:43:30.680 at d sigma square by d mu, you would get some mathematical formula for how mu impacts sigma square.
01:43:37.000 But if it is the special case that mu is actually equal to the average, as it is in the case of
01:43:42.200 batch normalization, that gradient will actually vanish and become zero. So the whole term cancels,
01:43:47.480 and we just get a fairly straightforward expression here for dl by d mu. Okay, and now we get to the
01:43:53.160 craziest part, which is deriving dl by d X i, which is ultimately what we're after. Now, let's count
01:44:00.920 first of all, how many numbers are there inside X? As I mentioned, there are 32 numbers. There
01:44:05.560 are 32 little X i's. And let's count the number of arrows emanating from each X i. There's an
01:44:11.400 arrow going to mu, an arrow going to sigma square, and then there's an arrow going to X hat. But
01:44:17.080 this arrow here, let's scrutinize that a little bit. Each X i hat is just a function of X i and all
01:44:24.200 the other scalars. So X i hat only depends on X i and none of the other X's. And so therefore,
01:44:31.400 in this single arrow there are actually 32 arrows, but those 32 arrows are going
01:44:36.600 exactly parallel. They don't interfere. They're just going parallel between X and X hat. You can
01:44:41.560 look at it that way. And so how many arrows are emanating from each X i, there are three arrows,
01:44:46.600 mu, sigma square, and the associated X hat. And so in back propagation, we now need to apply the chain
01:44:53.640 rule. And we need to add up those three contributions. So here's what that looks like. If I just write
01:44:59.720 that out, we're chaining through mu, through sigma square and through X hat.
01:45:07.560 And those three terms are just here. Now we already have three of these. We have dl by d X i hat.
01:45:14.760 We have dl by d mu, which we derived here. And we have dl by d sigma square, which we derived here.
01:45:21.080 But we need three other terms here. This one, this one, and this one. So I invite you to try to
01:45:27.480 derive them. It's not that complicated. You're just looking at these expressions here and
01:45:31.400 differentiating with respect to X i. So give it a shot, but here's the result. Or at least what I got.
01:45:40.280 Yeah, I'm just differentiating with respect to X i for all these expressions. And
01:45:46.200 honestly, I don't think there's anything too tricky here. It's basic calculus.
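Collected in one place, the pieces of this derivation look like this (my own working, following the order above and using the Bessel-corrected variance). The last line is the three-path chain rule that gets assembled next:

```latex
\frac{\partial L}{\partial \hat{x}_i} = \gamma \, \frac{\partial L}{\partial y_i}

\frac{\partial L}{\partial \sigma^2}
  = -\frac{1}{2} \left( \sigma^2 + \epsilon \right)^{-3/2}
    \sum_{i=1}^{m} \frac{\partial L}{\partial \hat{x}_i} \left( x_i - \mu \right)

\frac{\partial L}{\partial \mu}
  = -\left( \sigma^2 + \epsilon \right)^{-1/2} \sum_{i=1}^{m} \frac{\partial L}{\partial \hat{x}_i}
    + \frac{\partial L}{\partial \sigma^2} \cdot
      \underbrace{\frac{-2}{m-1} \sum_{i=1}^{m} \left( x_i - \mu \right)}_{=\,0 \text{ when } \mu \text{ is the batch mean}}

\frac{\partial \hat{x}_i}{\partial x_i} = \left( \sigma^2 + \epsilon \right)^{-1/2},
\qquad
\frac{\partial \sigma^2}{\partial x_i} = \frac{2 \left( x_i - \mu \right)}{m-1},
\qquad
\frac{\partial \mu}{\partial x_i} = \frac{1}{m}

\frac{\partial L}{\partial x_i}
  = \frac{\partial L}{\partial \hat{x}_i} \frac{\partial \hat{x}_i}{\partial x_i}
  + \frac{\partial L}{\partial \sigma^2} \frac{\partial \sigma^2}{\partial x_i}
  + \frac{\partial L}{\partial \mu} \frac{\partial \mu}{\partial x_i}
```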
01:45:50.520 Where it gets a little bit more tricky is when we plug everything together.
01:45:54.600 So all of these terms multiplied with all of these terms and added up according to this formula.
01:45:59.160 And that gets a little bit hairy. So what ends up happening is,
01:46:02.280 you get a large expression. And the thing to be very careful with here, of course, is we are
01:46:10.600 working with dl by d X i for a specific i here. But when we are plugging in some of these terms,
01:46:17.400 like say, this term here, dl by d sigma squared, you see how dl by d sigma squared, I end up with
01:46:24.840 an expression. And I'm iterating over little i's here. But I can't use i as the variable when I
01:46:31.480 plug in here, because this is a different i from this i. This i here is just a placeholder,
01:46:36.440 a local variable for a for loop in here. So here, when I plug that in, you notice that I
01:46:42.040 rename the i to a j, because I need to make sure that this j is not this i.
01:46:48.040 This j is like a little local iterator over 32 terms. And so you have to be careful with that.
01:46:53.560 When you're plugging in the expressions from here to here, you may have to rename i's into j's.
01:46:57.960 You have to be very careful what is actually an i with respect to dl by d X i. So some of these are
01:47:05.320 j's. Some of these are i's. And then we simplify this expression. And I guess like the big thing to
01:47:13.240 notice here is that a bunch of terms are just going to come out to the front and you can refactor them.
01:47:17.640 There's a sigma squared plus epsilon raised to the power of negative three over two. This sigma
01:47:21.720 squared plus epsilon can be actually separated out into three terms. Each of them are sigma
01:47:26.840 squared plus epsilon to the negative one over two. So the three of them multiplied is equal to this.
01:47:33.400 And then those three terms can go different places because of the multiplication. So one of them
01:47:38.040 actually comes out to the front and will end up here outside. One of them joins up with this term
01:47:44.840 and one of them joins up with this other term. And then when you simplify the expression,
01:47:49.160 you'll notice that some of these terms that are coming out are just the X i hats. So you can
01:47:54.600 simplify just by rewriting that. And what we end up with at the end is a fairly simple mathematical
01:47:59.800 expression over here that I cannot simplify further. But basically, you'll notice that
01:48:04.520 it only uses the stuff we have and it derives the thing we need. So we have dl by dy for all the
01:48:11.240 i's. And those are used plenty of times here. And in addition, what we're using is these
01:48:16.840 X i hats and X j hats. And they just come from the forward pass. And otherwise, this is a
01:48:22.440 simple expression and it gives us dl by d xi for all the i's. And that's ultimately what we're
01:48:28.040 interested in. So that's the end of the Batchnorm backward pass analytically. Let's now implement
01:48:35.000 this final result. Okay, so I implemented the expression into a single line of code here. And
01:48:40.600 you can see that the max diff is tiny. So this is the correct implementation of this formula.
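Carrying out the substitution and simplification (renaming the inner index to j and factoring out (sigma^2 + epsilon)^(-1/2)), I get the following closed form, which uses only dL/dy, the x̂'s and gamma. This is my own reconstruction with the 1/(m-1) variance; with the paper's biased 1/m estimate, the m/(m-1) factor becomes 1:

```latex
\frac{\partial L}{\partial x_i}
  = \frac{\gamma \left( \sigma^2 + \epsilon \right)^{-1/2}}{m}
    \left[ m \frac{\partial L}{\partial y_i}
         - \sum_{j=1}^{m} \frac{\partial L}{\partial y_j}
         - \frac{m}{m-1} \, \hat{x}_i \sum_{j=1}^{m} \frac{\partial L}{\partial y_j} \hat{x}_j \right]
```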
01:48:45.320 Now, I'll just basically tell you that getting this formula here from this mathematical expression
01:48:53.080 was not trivial. And there's a lot going on packed into this one formula. And this is a whole
01:48:57.480 exercise by itself, because you have to consider the fact that this formula here is just for a
01:49:02.920 single neuron and a batch of 32 examples. But what I'm doing here is I'm actually, we actually
01:49:08.280 have 64 neurons. And so this expression has to evaluate the Batchnorm backward pass in parallel for
01:49:14.040 all those 64 neurons, independently. So this has to happen basically in every single
01:49:19.240 column of the inputs here. And in addition to that, you see how there are a bunch of sums
01:49:27.240 here. And we need to make sure that when I do those sums, that they broadcast correctly onto
01:49:31.480 everything else that's here. And so getting this expression is just like highly non trivial. And
01:49:36.040 I invite you to basically look through it and step through it. And it's a whole exercise to make
01:49:39.240 sure that this checks out. But once all the shapes agree, and once you convince yourself that it's
01:49:46.040 correct, you can also verify that PyTorch gets the exact same answer as well. And so that gives
01:49:50.680 you a lot of peace of mind that this mathematical formula is correctly implemented here, broadcast
01:49:55.960 correctly and replicated in parallel for all the 64 neurons inside this Batchnorm layer.
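A sketch of that single line, checked against autograd. The shapes (32 examples, 64 neurons), eps = 1e-5 and the random gain, bias and upstream gradient are illustrative assumptions. The point is that every .sum(0) reduces over the batch and broadcasts back across the rows, so all 64 columns are handled independently in one expression:

```python
import torch

torch.manual_seed(0)
n, n_hidden, eps = 32, 64, 1e-5
dt = torch.float64
# Hypothetical inputs: pre-batchnorm activations, gain, bias, upstream gradient
hprebn = torch.randn(n, n_hidden, dtype=dt, requires_grad=True)
bngain = torch.randn(1, n_hidden, dtype=dt)
bnbias = torch.randn(1, n_hidden, dtype=dt)
dhpreact = torch.randn(n, n_hidden, dtype=dt)

# Glued forward pass with the Bessel-corrected variance
bnmean = hprebn.mean(0, keepdim=True)
bnvar = ((hprebn - bnmean) ** 2).sum(0, keepdim=True) / (n - 1)
bnvar_inv = (bnvar + eps) ** -0.5
bnraw = (hprebn - bnmean) * bnvar_inv
hpreact = bngain * bnraw + bnbias

# Reference: autograd, seeded with the upstream gradient
hpreact.backward(dhpreact)

# Manual backward in one expression. Each .sum(0) reduces over the batch to
# shape (64,) and broadcasts back over the 32 rows, so every column (neuron)
# is handled independently.
with torch.no_grad():
    dhprebn = bngain * bnvar_inv / n * (
        n * dhpreact
        - dhpreact.sum(0)
        - n / (n - 1) * bnraw * (dhpreact * bnraw).sum(0)
    )

print("max diff:", (dhprebn - hprebn.grad).abs().max().item())
```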
01:50:02.040 Okay, and finally exercise number four asks you to put it all together. And here we have a
01:50:08.040 redefinition of the entire problem. So you see that we reinitialize the neural net from scratch and
01:50:12.280 everything. And then here, instead of calling loss.backward(), we want to have the manual
01:50:17.880 back propagation here as we derived it up above. So go up, copy paste all the chunks of code that
01:50:23.160 we've already derived, put them here and derive your own gradients, and then optimize this neural
01:50:28.360 net, basically using your own gradients all the way to the calibration of the Batchnorm and the
01:50:33.560 evaluation of the loss. And I was able to achieve quite a good loss, basically the same loss you
01:50:37.960 would achieve before. And that shouldn't be surprising, because all we've done is we've
01:50:42.040 really gone into loss.backward, and we've pulled out all the code and inserted it here.
01:50:47.960 But those gradients are identical, and everything is identical, and the results are identical. It's
01:50:52.680 just that we have full visibility on exactly what goes on under the hood of loss.backward
01:50:57.640 in this specific case. Okay, and this is all of our code. This is the full backward pass,
01:51:02.840 using basically the simplified backward pass for the cross entropy loss and the batch normalization.
01:51:08.680 So back propagating through cross entropy, the second layer, the tanh nonlinearity,
01:51:14.520 the batch normalization, through the first layer, and through the embedding. And so you see that this
01:51:20.840 is only maybe what is this 20 lines of code or something like that, and that's what gives us
01:51:25.400 gradients. And now we can potentially erase loss.backward. So the way I have the code set up is
01:51:31.480 you should be able to run this entire cell once you fill this in. And this will run for only 100
01:51:36.200 iterations and then break. And it breaks because it gives you an opportunity to check your gradients
01:51:40.920 against PyTorch. So here, our gradients we see are not exactly equal. They are approximately equal,
01:51:49.400 and the differences are tiny, 1e-9 or so. And I don't exactly know where they're
01:51:53.640 coming from, to be honest. So once we have some confidence that the gradients are basically correct,
01:51:58.360 we can take out the gradient checking. We can disable this breaking statement.
01:52:04.440 And then we can basically disable loss.backward. We don't need it anywhere. Feels amazing to say that.
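To make the "about 20 lines" concrete, here is a self-contained sketch of a full manual backward pass for a network of this shape: embedding, linear, batchnorm, tanh, linear, cross-entropy. Sizes, initialization and the random batch are hypothetical stand-ins; loss.backward() is called once only to produce reference gradients for the check, and the embedding gradient uses index_add_ to scatter rows back into C (one alternative to an explicit loop):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(42)
vocab_size, block_size, n_embd, n_hidden, n = 27, 3, 10, 64, 32
dt = torch.float64

# Hypothetical parameters (sizes follow the lecture's network; values are random)
C = torch.randn(vocab_size, n_embd, dtype=dt)
W1 = torch.randn(block_size * n_embd, n_hidden, dtype=dt) * 0.2
b1 = torch.randn(n_hidden, dtype=dt) * 0.1
W2 = torch.randn(n_hidden, vocab_size, dtype=dt) * 0.1
b2 = torch.randn(vocab_size, dtype=dt) * 0.1
bngain = 1.0 + 0.1 * torch.randn(1, n_hidden, dtype=dt)
bnbias = 0.1 * torch.randn(1, n_hidden, dtype=dt)
parameters = [C, W1, b1, W2, b2, bngain, bnbias]
for p in parameters:
    p.requires_grad_(True)

# Hypothetical mini-batch of character indices
Xb = torch.randint(0, vocab_size, (n, block_size))
Yb = torch.randint(0, vocab_size, (n,))

# Forward pass
emb = C[Xb]                                      # (n, block_size, n_embd)
embcat = emb.view(n, -1)                         # (n, block_size * n_embd)
hprebn = embcat @ W1 + b1                        # (n, n_hidden)
bnmean = hprebn.mean(0, keepdim=True)
bnvar = ((hprebn - bnmean) ** 2).sum(0, keepdim=True) / (n - 1)
bnvar_inv = (bnvar + 1e-5) ** -0.5
bnraw = (hprebn - bnmean) * bnvar_inv
hpreact = bngain * bnraw + bnbias
h = torch.tanh(hpreact)
logits = h @ W2 + b2
loss = F.cross_entropy(logits, Yb)
loss.backward()                                  # reference gradients for checking only

# Manual backward pass
with torch.no_grad():
    dlogits = F.softmax(logits, 1)               # cross-entropy
    dlogits[torch.arange(n), Yb] -= 1
    dlogits /= n
    dh = dlogits @ W2.T                          # second linear layer
    dW2 = h.T @ dlogits
    db2 = dlogits.sum(0)
    dhpreact = (1.0 - h ** 2) * dh               # tanh
    dbngain = (bnraw * dhpreact).sum(0, keepdim=True)   # batchnorm
    dbnbias = dhpreact.sum(0, keepdim=True)
    dhprebn = bngain * bnvar_inv / n * (
        n * dhpreact - dhpreact.sum(0)
        - n / (n - 1) * bnraw * (dhpreact * bnraw).sum(0))
    dembcat = dhprebn @ W1.T                     # first linear layer
    dW1 = embcat.T @ dhprebn
    db1 = dhprebn.sum(0)
    demb = dembcat.view(emb.shape)               # embedding: scatter-add rows into C
    dC = torch.zeros_like(C)
    dC.index_add_(0, Xb.view(-1), demb.view(-1, n_embd))

grads = [dC, dW1, db1, dW2, db2, dbngain, dbnbias]
for p, g in zip(parameters, grads):
    print(tuple(p.shape), "max diff:", (g - p.grad).abs().max().item())
```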
01:52:13.400 And then here, when we are doing the update, we're not going to use p.grad.
01:52:18.520 This is the old way of PyTorch. We don't have that anymore, because we're not doing backward.
01:52:23.400 We are going to use this update, where you see that I'm iterating over, I've arranged the grads
01:52:30.120 to be in the same order as the parameters, and I'm zipping them up, the gradients and the parameters
01:52:35.160 into p and grad. And then here, I'm going to step with just the grad that we derived manually.
01:52:40.920 So the last piece is that none of this now requires gradients from PyTorch.
01:52:48.360 And so one thing you can do here is put this whole code block under with torch.no_grad().
01:52:57.160 And really what you're saying is you're telling pytorch that, hey, I'm not going to call backward
01:53:01.160 on any of this. And this lets PyTorch be a bit more efficient with all of it. And then we
01:53:06.040 should be able to just run this. And it's running. And you see that loss.backward is commented
01:53:16.920 out, and we're optimizing. So we're going to let this run, and hopefully we get a good result.
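The update loop and the torch.no_grad() wrapper described here can be sketched on a hypothetical toy problem (a single linear layer fit with mean-squared error, chosen only to keep the example short). The parameters keep requires_grad=True, as they would from an autograd-based setup, so the in-place updates run inside torch.no_grad(); the grads list is kept in the same order as parameters and the two are zipped together:

```python
import torch

torch.manual_seed(0)
# Hypothetical toy problem: fit y = X @ w_true + 0.3 with one linear layer
N = 64
X = torch.randn(N, 3)
y = X @ torch.tensor([[1.0], [-2.0], [0.5]]) + 0.3

W = torch.randn(3, 1)
b = torch.zeros(1)
parameters = [W, b]
for p in parameters:
    p.requires_grad = True        # as if left over from an autograd-based setup

lr = 0.1
with torch.no_grad():            # no graph is built; in-place updates are allowed
    for step in range(500):
        pred = X @ W + b
        loss = ((pred - y) ** 2).mean()
        # manual backward pass for the mean-squared error
        dpred = 2.0 * (pred - y) / N
        dW = X.T @ dpred
        db = dpred.sum(0)
        grads = [dW, db]          # same order as `parameters`
        for p, grad in zip(parameters, grads):
            p += -lr * grad       # step with the manually derived gradient, not p.grad
print("final loss:", loss.item())
```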
01:53:24.440 Okay, so I allowed the neural net to finish optimization. Then here, I calibrate the
01:53:30.760 batchnorm parameters, because I did not keep track of the running mean and variance in the
01:53:35.880 training loop. Then here, I evaluated the loss. And you see that we actually obtained a pretty good loss,
01:53:41.240 very similar to what we achieved before. And then here, I'm sampling from the model,
01:53:45.720 and we see some of the name-like gibberish that we're sort of used to. So basically, the model
01:53:50.200 worked and samples pretty decent results, compared to what we're used to. So everything is the same.
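The calibration step mentioned above (estimating the batchnorm statistics once over the whole training set, since no running mean and variance were tracked during training) might look like the following. The training-set activations and first-layer weights here are random hypothetical stand-ins:

```python
import torch

torch.manual_seed(0)
eps = 1e-5
# Hypothetical stand-ins for the concatenated embeddings of the whole
# training set and a trained first layer
n_train, n_in, n_hidden = 1000, 30, 64
embcat_train = torch.randn(n_train, n_in)
W1 = torch.randn(n_in, n_hidden) * 0.2
b1 = torch.randn(n_hidden) * 0.1

with torch.no_grad():
    hpreact_train = embcat_train @ W1 + b1
    # One pass over the full training set replaces running statistics
    bnmean = hpreact_train.mean(0, keepdim=True)
    bnvar = ((hpreact_train - bnmean) ** 2).sum(0, keepdim=True) / (n_train - 1)

# At inference time these fixed statistics replace the per-batch ones
normalized = (hpreact_train - bnmean) * (bnvar + eps) ** -0.5
print(bnmean.shape, bnvar.shape)
```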
01:53:56.760 But of course, the big deal is that we did not use loss.backward. We did not use PyTorch autograd,
01:54:01.720 and we computed our gradients ourselves by hand. And so hopefully you're looking at this,
01:54:06.600 the backward pass of this neural net, and you're thinking to yourself, actually, that's not too
01:54:11.000 complicated. Each one of these layers is like three lines of code or something like that. And
01:54:17.000 most of it is fairly straightforward, potentially with the notable exception of the batch normalization
01:54:22.040 backward pass. Otherwise, it's pretty good. Okay, and that's everything I wanted to cover for this
01:54:26.600 lecture. So hopefully you found this interesting. And what I liked about it honestly, is that it
01:54:32.040 gave us a very nice diversity of layers to back propagate through. And I think it gives a pretty
01:54:37.800 nice and comprehensive sense of how these backward passes are implemented and how they work. And
01:54:42.680 you'd be able to derive them yourself. But of course, in practice, you probably don't want to,
01:54:46.200 and you want to use PyTorch autograd. But hopefully you have some intuition about how
01:54:50.280 gradients flow backwards through the neural net, starting at the loss, and how they flow through
01:54:55.160 all the variables and all the intermediate results. And if you understood a good chunk of it, and if
01:55:00.680 you have a sense of that, then you can count yourself as one of these buff doggies on the left,
01:55:04.680 instead of the doggies on the right here. Now, in the next lecture, we're actually going to go to
01:55:10.360 recurrent neural nets, LSTMs, and all the other variants of RNNs. And we're going to start to
01:55:16.200 complexify the architecture and start to achieve better log likelihoods. And so I'm really looking