Kyle Pena

Visualizing Decoder Layer Gradients

In this short post, I'll explain a practical problem you'll encounter when visualizing the gradients of decoder layers, and how to resolve it.

The Llama 3.2-1B model consists of a token input embedding layer, 16 stacked decoder layers, and a token prediction head.
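
If you want to double-check those numbers yourself, the Hugging Face config exposes them. A quick sketch (assuming the transformers library and access to the meta-llama/Llama-3.2-1B checkpoint):

from transformers import AutoConfig

# Quick sanity check of the architecture described above
config = AutoConfig.from_pretrained("meta-llama/Llama-3.2-1B")
print(config.num_hidden_layers)  # number of stacked decoder layers
print(config.hidden_size)        # model dimension (2048)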

Each decoder layer takes as input a hidden state tensor of dimension (B, N, 2048), where B is the batch dimension, N is the number of tokens, and 2048 is the model dimension. Therefore H[0, 1, 10] represents the activation of the 11th "neuron" of the second token, with a batch size of one.

PyTorch's autograd allows us to compute the partial derivative of any activation (call it $\alpha$) with respect to some earlier layer of the network - say, for layer $j$, $H^j$:

tokenized = tokenizer(input_text, return_tensors="pt", padding=True).to(device)
result = model(**tokenized, output_hidden_states=True)
hidden_states = result.hidden_states
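
# alpha is any scalar activation taken from a later layer, for example
# (token_idx and neuron_idx here are placeholder indices):
# alpha = hidden_states[j + 1][0, token_idx, neuron_idx]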
gradient = torch.autograd.grad(
    outputs=alpha,
    inputs=hidden_states[j],
)[0]

torch.autograd.grad only computes the gradient of a single scalar activation at a time, so for the complete picture we would need to loop over every activation in the next layer $H^{j+1}$:

gradients = {}
token_idx = ...
for i in range(2048):
    alpha = hidden_states[j+1][:, token_idx, i]
    gradients[i] = torch.autograd.grad(
        outputs=alpha,
        inputs=hidden_states[j],
        retain_graph=True,  # keep the graph alive across the 2048 calls
    )[0]

As a heatmap, the result looks something like this (this is called the Jacobian):
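
A sketch of how such a heatmap could be produced from the gradients dictionary above (matplotlib assumed; each entry of gradients has the same shape as $H^j$, so we slice out the same token position to get a square matrix):

import matplotlib.pyplot as plt
import torch

# Row i is the gradient of H^{j+1}[:, token_idx, i] with respect to
# H^{j}[:, token_idx, :], giving a (2048, 2048) Jacobian for one token.
jacobian = torch.stack(
    [gradients[i][0, token_idx, :] for i in range(2048)]
)

plt.imshow(jacobian.cpu().numpy(), cmap="RdBu_r")
plt.colorbar()
plt.xlabel("neuron in layer j")
plt.ylabel("neuron in layer j+1")
plt.show()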

Notice the strong diagonal? This is not a bug. It's because of skip connections.

Skip Connections

Skip connections are ubiquitous in LLM architectures, as they are one of the primary reasons it is possible to train very deep networks. If we are fitting a layer $f(x) = H(x) + x$, where $H(x)$ is the main branch (e.g. self-attention), then $H(x)$ can be interpreted as the residual $f(x) - x$. Thus each layer is responsible for learning an additive update to the input $x$ instead of a wholesale transformation. This makes deep networks easier to train, as in some sense fitting the residual is "easier". It also appears to have a strange smoothing effect on the loss landscape, which is no doubt related. You can read the ResNet paper for more details.
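
As a minimal, generic sketch (not the Llama code), a residual block just adds its input back onto whatever the main branch computes:

import torch.nn as nn

class ResidualBlock(nn.Module):
    def __init__(self, main_branch: nn.Module):
        super().__init__()
        self.main_branch = main_branch  # H(x), e.g. self-attention or an MLP

    def forward(self, x):
        return x + self.main_branch(x)  # f(x) = H(x) + x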

However, this is problematic for our purposes because we can't read the "real" gradients along the diagonal - they are obscured by the overwhelming effect of the skip connection.

The Fix

This is tricky to fix if we take the decoder layer as a whole. However, if we focus on the self-attention layer (or the fully connected layer) within the decoder block, it's very manageable.

Let's take a look at this (slightly simplified) implementation of the decoder layer.

        # Self Attention
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states, self_attn_weights = self.self_attn(...)
        hidden_states = residual + hidden_states # Skip connection

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states # Skip connection

        return hidden_states

Notice that the implementation is cleanly separated into two "blocks": Self-Attention and Fully Connected.

To make it a bit clearer, we might write it like this:

    def forward(self, hidden_states):

        residual = hidden_states
        hidden_states = self.self_attention(hidden_states)
        hidden_states = residual + hidden_states # <- Skip Connection


        residual = hidden_states
        hidden_states = self.fully_connected(hidden_states)
        hidden_states = residual + hidden_states # <- Skip Connection

        return hidden_states

First, we'll need a few hook helper functions to store references to the relevant tensors.


import logging

layer_ins = {}
layer_outs = {}

def _save_layer_output(tag):
    def hook(_mod, args, kwargs, out):
        logging.debug(f"Saving output from {tag}, type: {type(out)}")
        if isinstance(out, tuple):
            out = out[0]
        layer_outs[tag] = out
    return hook

def _save_layer_input(tag):
    def hook(_mod, args, kwargs):
        logging.debug(f"Saving input to {tag}, type: {type(args)}")
        if isinstance(args, tuple):
            args = args[0]
        layer_ins[tag] = args
    return hook

Then, we can register the hooks. register_forward_pre_hook captures the input to an operation and register_forward_hook captures the output. If you refer back to the complete decoder layer forward pass implementation, input_layernorm and post_attention_layernorm are the cutpoints in the computation graph that we need in order to isolate the self-attention portion of the decoder layer. Since our hook functions accept kwargs, we register them with with_kwargs=True.

def setup_hooks_attn():
    for idx, layer in enumerate(model.model.layers):
        # with_kwargs=True so the registrations match the hook signatures above
        layer.input_layernorm.register_forward_pre_hook(
            _save_layer_input(f"input_layernorm_{idx}"), with_kwargs=True
        )
        layer.post_attention_layernorm.register_forward_hook(
            _save_layer_output(f"post_attention_layernorm_{idx}"), with_kwargs=True
        )
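
With the hooks registered, the Jacobian of the isolated self-attention block can be computed exactly as before, just between the hooked tensors. A sketch (the layer index k and token_idx below are placeholders):

setup_hooks_attn()
result = model(**tokenized, output_hidden_states=True)

k = ...          # decoder layer index
token_idx = ...  # token position
block_in = layer_ins[f"input_layernorm_{k}"]
block_out = layer_outs[f"post_attention_layernorm_{k}"]

rows = []
for i in range(2048):
    grad = torch.autograd.grad(
        outputs=block_out[:, token_idx, i],
        inputs=block_in,
        retain_graph=True,
    )[0]
    rows.append(grad[0, token_idx, :])
jacobian = torch.stack(rows)  # (2048, 2048)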

The gradient visualization for self-attention still shows the same diagonal dominance phenomenon:

But it disappears if we simply subtract the identity matrix I (note the updated colormap scale).

Why It Works

The explanation is simple. The self-attention block with skip connection $f(x) = H(x) + x$ has a Jacobian (read: the matrix of partial derivatives, i.e. what autograd gives us) with respect to $x$ of $J_f = J_{H(x)} + I$.

Therefore, the Jacobian of just the main branch $H(x)$ is $J_{H(x)} = J_f - I$. In plain English, this simply means we should subtract the identity matrix from the gradient matrix.
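
In code that's a one-liner, reusing the jacobian matrix computed above:

# Remove the skip connection's contribution: J_H = J_f - I
jacobian_main = jacobian - torch.eye(jacobian.shape[0], device=jacobian.device)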

This fix works equally well for the fully connected portion of the decoder layer.
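
The hooks for that case aren't shown above, but one reasonable choice of cutpoints is a pre-hook on post_attention_layernorm (the input to the fully connected sub-block) and a forward hook on the decoder layer itself (its output). A sketch:

def setup_hooks_mlp():
    # One possible choice of cutpoints for the fully connected sub-block:
    # input  = the residual stream entering post_attention_layernorm
    # output = the decoder layer's return value (residual + MLP output)
    for idx, layer in enumerate(model.model.layers):
        layer.post_attention_layernorm.register_forward_pre_hook(
            _save_layer_input(f"post_attention_layernorm_{idx}"), with_kwargs=True
        )
        layer.register_forward_hook(
            _save_layer_output(f"mlp_block_{idx}"), with_kwargs=True
        )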
