BackPACK

Get more out of your backward pass

BackPACK is a library built on top of PyTorch to make it easy to extract more information from a backward pass. Some of the things you can compute:

"""
Compute the gradient with PyTorch
"""
from torch.nn import CrossEntropyLoss, Linear
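# Local helper (not part of PyTorch); assumed to return flattened MNIST images X
# of shape (N, 784) and integer class labels y of shape (N,)
from utils import load_mnist_data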


X, y = load_mnist_data()
model = Linear(784, 10)  # 28 * 28 = 784 input pixels, 10 digit classes
lossfunc = CrossEntropyLoss()
loss = lossfunc(model(X), y)


# Standard backward pass: fills param.grad for every parameter
loss.backward()

for param in model.parameters():
    print(param.grad)
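The snippet above depends on a local utils.load_mnist_data helper and yields only the batch-averaged gradient. As a hedged illustration of what "more information" from the same backward pass can mean, the sketch below uses plain PyTorch autograd plus einsum (not BackPACK's API or implementation) to recover per-sample gradients of the linear layer. It assumes random tensors shaped like a flattened MNIST batch in place of real data; the batch size N = 8, the seed, and all variable names are illustrative choices. The idea: for a linear layer z_n = W x_n + b, sample n contributes the outer product of dL/dz_n with x_n to the gradient of W, and dL/dz_n is already computed during the ordinary backward pass.

```python
import torch
from torch.nn import CrossEntropyLoss, Linear

torch.manual_seed(0)

# Assumption: random stand-in for a flattened MNIST batch (N images of 28 * 28 = 784 pixels)
N = 8
X = torch.randn(N, 784)
y = torch.randint(0, 10, (N,))

model = Linear(784, 10)
lossfunc = CrossEntropyLoss()  # default reduction="mean" averages over the batch

logits = model(X)
logits.retain_grad()  # keep dL/dz on this non-leaf tensor after backward()
loss = lossfunc(logits, y)
loss.backward()

# Per-sample gradients of the linear layer: for z_n = W x_n + b,
# sample n contributes outer(dL/dz_n, x_n) to dL/dW and dL/dz_n to dL/db.
g = logits.grad.clone()                               # shape (N, 10), includes the 1/N factor
grad_W_per_sample = torch.einsum("ni,nj->nij", g, X)  # shape (N, 10, 784)
grad_b_per_sample = g                                 # shape (N, 10)

# Summing the per-sample terms recovers the ordinary batch gradient
sums_match = torch.allclose(
    grad_W_per_sample.sum(0), model.weight.grad, atol=1e-6
) and torch.allclose(grad_b_per_sample.sum(0), model.bias.grad, atol=1e-6)

# Cross-check sample 0 against its own backward pass, scaled by 1/N for the mean reduction
model.zero_grad()
(lossfunc(model(X[:1]), y[:1]) / N).backward()
first_matches = torch.allclose(
    model.weight.grad, grad_W_per_sample[0], atol=1e-6
) and torch.allclose(model.bias.grad, grad_b_per_sample[0], atol=1e-6)

print(grad_W_per_sample.shape, sums_match, first_matches)
```

Because CrossEntropyLoss averages over the batch by default, each per-sample term carries a 1/N factor, which is why the single-sample check divides its loss by N. Reading all N terms off one backward pass avoids running N separate single-sample passes.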