import torch.nn.functional as F,matplotlib as mpl
from pathlib import Path
from operator import attrgetter,itemgetter
from contextlib import contextmanager
from torch import tensor,nn,optim
import torchvision.transforms.functional as TF
from datasets import load_dataset
from fastcore.test import test_close
torch.set_printoptions(precision=2, linewidth=140, sci_mode=False)
mpl.rcParams['figure.constrained_layout.use'] = True
import logging
logging.disable(logging.WARNING)
This is not my own content; it's part of fastai's From Deep Learning Foundations to Stable Diffusion course. I've just added some notes to help myself understand it better. For the source, check the fastai course page.
Activations, hooks and histograms.
Activation stats
::: {.cell}
#|export
from __future__ import annotations
import random,math,torch,numpy as np,matplotlib.pyplot as plt
import fastcore.all as fc
from functools import partial
from miniai.datasets import *
from miniai.learner import *
:::
::: {.cell}
#|export
def set_seed(seed, deterministic=False):
    torch.use_deterministic_algorithms(deterministic)
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
:::
x,y = 'image','label'
name = "fashion_mnist"
dsd = load_dataset(name)
bs = 1024
@inplace
def transformi(b): b[x] = [TF.to_tensor(o) for o in b[x]]
tds = dsd.with_transform(transformi)
dls = DataLoaders.from_dd(tds, bs, num_workers=4)
dt = dls.train
Baseline
def conv(ni, nf, ks=3, act=True):
    res = nn.Conv2d(ni, nf, stride=2, kernel_size=ks, padding=ks//2)
    if act: res = nn.Sequential(res, nn.ReLU())
    return res
def cnn_layers():
    return [
        conv(1 ,8, ks=5),        #14x14
        conv(8 ,16),             #7x7
        conv(16,32),             #4x4
        conv(32,64),             #2x2
        conv(64,10, act=False),  #1x1
        nn.Flatten()]
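Each stride-2 convolution halves the spatial resolution (the comments track it: the 28×28 input goes to 14, 7, 4, 2 and finally 1×1 with 10 channels), and Flatten turns that last 1×1 activation into the 10 class logits.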
We want to train quickly, so that means training at a high learning rate.
from torcheval.metrics import MulticlassAccuracy
metrics = MetricsCB(accuracy=MulticlassAccuracy())
cbs = [TrainCB(), DeviceCB(), metrics, ProgressCB(plot=True)]
def fit(model, epochs=1, xtra_cbs=None):
    learn = Learner(model, dls, loss_func=F.cross_entropy, lr=0.6, cbs=cbs+fc.L(xtra_cbs))
    learn.fit(epochs)
    return learn
set_seed(1)
learn = fit(nn.Sequential(*cnn_layers()))
accuracy | loss | epoch | train |
---|---|---|---|
0.208 | 2.243 | 0 | train |
0.204 | 2.165 | 0 | eval |
Hooks
Manual insertion
class SequentialModel(nn.Module):
    def __init__(self, *layers):
        super().__init__()
        self.layers = nn.ModuleList(layers)
        self.act_means = [[] for _ in layers]
        self.act_stds  = [[] for _ in layers]

    def __call__(self, x):
        for i,l in enumerate(self.layers):
            x = l(x)
            self.act_means[i].append(to_cpu(x).mean())
            self.act_stds [i].append(to_cpu(x).std ())
        return x

    def __iter__(self): return iter(self.layers)
set_seed(1)
model = SequentialModel(*cnn_layers())
learn = fit(model)
accuracy | loss | epoch | train |
---|---|---|---|
0.247 | 2.110 | 0 | train |
0.376 | 1.659 | 0 | eval |
for l in model.act_means: plt.plot(l)
plt.legend(range(5));

for l in model.act_stds: plt.plot(l)
plt.legend(range(5));
PyTorch hooks
Hooks are PyTorch objects you can add to any nn.Module. A hook is called when the layer it is registered to is executed, during the forward pass (forward hook) or the backward pass (backward hook). Hooks don't require us to rewrite the model.
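As a quick toy sketch of the raw API (not part of the course code; the layer and function here are made up for illustration), registering a forward hook on a single layer looks like this:

# Toy illustration of register_forward_hook: the hook function receives
# (module, input, output) and runs on every forward pass.
lin = nn.Linear(4, 2)
def print_shape(module, inp, outp): print(type(module).__name__, tuple(outp.shape))
handle = lin.register_forward_hook(print_shape)
lin(torch.randn(3, 4))  # prints: Linear (3, 2)
handle.remove()         # detach the hook once we no longer need it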
set_seed(1)
model = nn.Sequential(*cnn_layers())
A hook is attached to a layer and needs a function that takes three arguments: module, input, output. Here we store the mean and std of the output at the correct position in our lists.
act_means = [[] for _ in model]
act_stds  = [[] for _ in model]

def append_stats(i, mod, inp, outp):
    act_means[i].append(to_cpu(outp).mean())
    act_stds [i].append(to_cpu(outp).std ())
for i,m in enumerate(model): m.register_forward_hook(partial(append_stats, i))
fit(model)
accuracy | loss | epoch | train |
---|---|---|---|
0.263 | 2.091 | 0 | train |
0.164 | 2.245 | 0 | eval |
<miniai.learner.Learner>
for o in act_means: plt.plot(o)
plt.legend(range(5));
Hook class
We can refactor this into a Hook class. It's very important to remove the hooks when we're done with them; otherwise references are kept alive and the memory won't be properly released when your model is deleted.
::: {.cell}
#|export
class Hook():
    def __init__(self, m, f): self.hook = m.register_forward_hook(partial(f, self))
    def remove(self): self.hook.remove()
    def __del__(self): self.remove()
:::
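Note the partial(f, self) trick: it binds the Hook instance itself as the first argument, so the hook function is called as f(hook, module, input, output) and can stash per-layer state on the hook object, as append_stats does below.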
def append_stats(hook, mod, inp, outp):
    if not hasattr(hook,'stats'): hook.stats = ([],[])
    acts = to_cpu(outp)
    hook.stats[0].append(acts.mean())
    hook.stats[1].append(acts.std())
set_seed(1)
model = nn.Sequential(*cnn_layers())

hooks = [Hook(l, append_stats) for l in model[:5].children()]
learn = fit(model)
accuracy | loss | epoch | train |
---|---|---|---|
0.247 | 2.122 | 0 | train |
0.397 | 1.456 | 0 | eval |
for h in hooks:
    plt.plot(h.stats[0])
    h.remove()
plt.legend(range(5));
A Hooks class
class DummyCtxMgr:
    def __enter__(self, *args):
        print("let's go!")
        return self
    def __exit__ (self, *args): print("all done!")
    def hello(self): print("hello.")
with DummyCtxMgr() as dcm: dcm.hello()
let's go!
hello.
all done!
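This works because the with statement is (roughly) sugar for calling __enter__ and __exit__ around the block. A simplified desugaring of the cell above (the real protocol also passes exception details to __exit__):

# Simplified equivalent of: with DummyCtxMgr() as dcm: dcm.hello()
mgr = DummyCtxMgr()
dcm = mgr.__enter__()
try: dcm.hello()
finally: mgr.__exit__(None, None, None)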
class DummyList(list):
    def __delitem__(self, i):
        print(f"Say bye to item {i}")
        super().__delitem__(i)
dml = DummyList([1,3,2])
dml
[1, 3, 2]
del(dml[2])
dml
Say bye to item 2
[1, 3]
::: {.cell}
#|export
class Hooks(list):
    def __init__(self, ms, f): super().__init__([Hook(m, f) for m in ms])
    def __enter__(self, *args): return self
    def __exit__ (self, *args): self.remove()
    def __del__(self): self.remove()
    def __delitem__(self, i):
        self[i].remove()
        super().__delitem__(i)
    def remove(self):
        for h in self: h.remove()
:::
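Hooks is simultaneously a list of Hook objects and a context manager, so with Hooks(model, append_stats) as hooks: guarantees every hook is removed when the block exits, even if training raises an exception.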
set_seed(1)
model = nn.Sequential(*cnn_layers())

with Hooks(model, append_stats) as hooks:
    fit(model)
    fig,axs = plt.subplots(1,2, figsize=(10,4))
    for h in hooks:
        for i in 0,1: axs[i].plot(h.stats[i])
    plt.legend(range(6));
accuracy | loss | epoch | train |
---|---|---|---|
0.166 | 2.475 | 0 | train |
0.100 | 2.303 | 0 | eval |
HooksCallback
::: {.cell}
#|export
class HooksCallback(Callback):
    def __init__(self, hookfunc, mod_filter=fc.noop, on_train=True, on_valid=False, mods=None):
        fc.store_attr()
        super().__init__()

    def before_fit(self, learn):
        if self.mods: mods=self.mods
        else: mods = fc.filter_ex(learn.model.modules(), self.mod_filter)
        self.hooks = Hooks(mods, partial(self._hookfunc, learn))

    def _hookfunc(self, learn, *args, **kwargs):
        if (self.on_train and learn.training) or (self.on_valid and not learn.training): self.hookfunc(*args, **kwargs)

    def after_fit(self, learn): self.hooks.remove()
    def __iter__(self): return iter(self.hooks)
    def __len__(self): return len(self.hooks)
:::
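mod_filter selects which modules get hooked (below, only the nn.Conv2d layers), mods lets you pass an explicit list of modules instead, and on_train/on_valid control whether the hook function actually records anything during training and validation batches.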
hc = HooksCallback(append_stats, mod_filter=fc.risinstance(nn.Conv2d))

set_seed(1)
model = nn.Sequential(*cnn_layers())
fit(model, xtra_cbs=[hc]);
accuracy | loss | epoch | train |
---|---|---|---|
0.205 | 2.232 | 0 | train |
0.100 | 2.305 | 0 | eval |
fig,axs = plt.subplots(1,2, figsize=(10,4))
for h in hc:
    for i in 0,1: axs[i].plot(h.stats[i])
plt.legend(range(6));
Histograms
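Instead of recording only a mean and std per layer, we now also store a full histogram of the absolute activations at every batch: histc(40, 0, 10) counts the activations falling into 40 equal bins between 0 and 10. A toy example of what torch.histc returns (the numbers here are made up for illustration):

# Toy histc example: 4 bins over [0, 2) -> counts per bin
t = torch.tensor([0.1, 0.2, 1.1, 1.9])
t.histc(4, 0, 2)  # tensor([2., 0., 1., 1.])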
::: {.cell}
#|export
def append_stats(hook, mod, inp, outp):
    if not hasattr(hook,'stats'): hook.stats = ([],[],[])
    acts = to_cpu(outp)
    hook.stats[0].append(acts.mean())
    hook.stats[1].append(acts.std())
    hook.stats[2].append(acts.abs().histc(40,0,10))
:::
set_seed(1)
model = nn.Sequential(*cnn_layers())
hc = HooksCallback(append_stats, mod_filter=fc.risinstance(nn.Conv2d))
fit(model, xtra_cbs=[hc]);
accuracy | loss | epoch | train |
---|---|---|---|
0.213 | 2.379 | 0 | train |
0.100 | 21.771 | 0 | eval |
::: {.cell}
#|export
# Thanks to @ste for initial version of histogram plotting code
def get_hist(h): return torch.stack(h.stats[2]).t().float().log1p()
:::
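get_hist stacks the per-batch histograms into a (steps, bins) tensor, transposes it to bins × steps so each column is one training step, and applies log1p so that bins with few activations stay visible next to the dominant near-zero bin.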
fig,axes = get_grid(len(hc), figsize=(11,5))
for ax,h in zip(axes.flat, hc):
    show_image(get_hist(h), ax, origin='lower')
::: {.cell}
#|export
def get_min(h):
    h1 = torch.stack(h.stats[2]).t().float()
    return h1[0]/h1.sum(0)
:::
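get_min takes the same stacked histogram and returns, for each training step, the fraction of activations that landed in the first bin, i.e. those with absolute value below 0.25 out of the 0 to 10 range. A value near 1.0 means the layer is almost entirely "dead" at that point.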
fig,axes = get_grid(len(hc), figsize=(11,5))
for ax,h in zip(axes.flatten(), hc):
    ax.plot(get_min(h))
    ax.set_ylim(0,1)
ActivationStats
::: {.cell}
#|export
class ActivationStats(HooksCallback):
    def __init__(self, mod_filter=fc.noop): super().__init__(append_stats, mod_filter)

    def color_dim(self, figsize=(11,5)):
        fig,axes = get_grid(len(self), figsize=figsize)
        for ax,h in zip(axes.flat, self):
            show_image(get_hist(h), ax, origin='lower')

    def dead_chart(self, figsize=(11,5)):
        fig,axes = get_grid(len(self), figsize=figsize)
        for ax,h in zip(axes.flatten(), self):
            ax.plot(get_min(h))
            ax.set_ylim(0,1)

    def plot_stats(self, figsize=(10,4)):
        fig,axs = plt.subplots(1,2, figsize=figsize)
        for h in self:
            for i in 0,1: axs[i].plot(h.stats[i])
        axs[0].set_title('Means')
        axs[1].set_title('Stdevs')
        plt.legend(fc.L.range(self))
:::
astats = ActivationStats(fc.risinstance(nn.Conv2d))

set_seed(1)
model = nn.Sequential(*cnn_layers())
fit(model, xtra_cbs=[astats]);
accuracy | loss | epoch | train |
---|---|---|---|
0.208 | 2.199 | 0 | train |
0.289 | 1.695 | 0 | eval |
astats.color_dim()
astats.dead_chart()
astats.plot_stats()
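color_dim draws the log-scaled activation histograms over training (the "colorful dimension" plots), dead_chart plots the per-step fraction of near-zero activations for each layer, and plot_stats shows the layer-wise means and standard deviations.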
Export
import nbdev; nbdev.nbdev_export()