![Writing better code with pytorch and einops]()
Writing better code with pytorch and einops
Rewriting building blocks of deep learning
Below are some fragments of code taken from official tutorials and popular repositories (fragments taken for educational purposes, sometimes shortened). For each fragment, an enhanced version is proposed, with comments.
In most examples, einops was used to make things less complicated. But you'll also find some common recommendations and practices to improve the code.
Left: as it was; right: improved version. (In this text version, the original code comes first and the improved version follows it.)
# start from importing some stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
from einops import rearrange, reduce, asnumpy, parse_shape
from einops.layers.torch import Rearrange, Reduce
Simple ConvNet
class Net(nn.Module):
    """Small ConvNet: two conv/pool/relu stages followed by two linear layers.

    Expects 1-channel images. The flatten step hard-codes 320 features
    (e.g. 28x28 inputs give 20 channels * 4 * 4). Returns log-probabilities
    over 10 classes.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, 5)
        self.conv2 = nn.Conv2d(10, 20, 5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        stage1 = F.relu(F.max_pool2d(self.conv1(x), 2))
        stage2 = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(stage1)), 2))
        # NOTE: -1 lets the batch axis silently absorb any spatial-size mismatch
        flat = stage2.view(-1, 320)
        hidden = F.dropout(F.relu(self.fc1(flat)), training=self.training)
        return F.log_softmax(self.fc2(hidden), dim=1)
# The same network twice: once through the Net class, once as a flat nn.Sequential.
conv_net_old = Net()
conv_net_new = nn.Sequential(
    # feature extractor: conv -> pool -> relu, twice
    nn.Conv2d(1, 10, 5),
    nn.MaxPool2d(2),
    nn.ReLU(),
    nn.Conv2d(10, 20, 5),
    nn.MaxPool2d(2),
    nn.ReLU(),
    nn.Dropout2d(),
    # flatten each sample; the batch axis 'b' is kept separate
    Rearrange('b c h w -> b (c h w)'),
    # classifier head
    nn.Linear(320, 50),
    nn.ReLU(),
    nn.Dropout(),
    nn.Linear(50, 10),
    nn.LogSoftmax(dim=1),
)
Reasons to prefer new code:
in the original code, if the input size changes and the batch size is divisible by 16 (which is usually the case), reshaping silently produces something meaningless
the new code explicitly raises an error in this case
with the new version we can't forget to pass the self.training flag to dropout
code is straightforward to read and analyze
sequential makes printing / saving / passing trivial, and loading a saved model does not require its class definition in your code
... and we could also use inplace=True for ReLU
Super-resolution
class SuperResolutionNetOld(nn.Module):
    """Sub-pixel super-resolution network (original formulation).

    Four convolutions map a 1-channel image to ``upscale_factor ** 2``
    channels. ``nn.PixelShuffle`` then rearranges those channels into an
    image that is ``upscale_factor`` times larger along each spatial axis.
    """

    def __init__(self, upscale_factor):
        super(SuperResolutionNetOld, self).__init__()
        self.relu = nn.ReLU()
        # padding keeps the spatial size unchanged for every convolution
        self.conv1 = nn.Conv2d(1, 64, kernel_size=5, stride=1, padding=2)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(32, upscale_factor ** 2, kernel_size=3, stride=1, padding=1)
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

    def forward(self, x):
        for conv in (self.conv1, self.conv2, self.conv3):
            x = self.relu(conv(x))
        return self.pixel_shuffle(self.conv4(x))
def SuperResolutionNetNew(upscale_factor):
    """Build the super-resolution network as a single ``nn.Sequential``.

    The final ``Rearrange`` does the depth-to-space step (the job of
    ``nn.PixelShuffle`` in the old version) and also drops the channel axis:
    the output is ``(b, h * upscale_factor, w * upscale_factor)``.
    """
    r = upscale_factor
    return nn.Sequential(
        nn.Conv2d(1, 64, kernel_size=5, padding=2), nn.ReLU(inplace=True),
        nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.ReLU(inplace=True),
        nn.Conv2d(64, 32, kernel_size=3, padding=1), nn.ReLU(inplace=True),
        nn.Conv2d(32, r ** 2, kernel_size=3, padding=1),
        # the r*r channels are split into (h2, w2) and interleaved into h and w
        Rearrange('b (h2 w2) h w -> b (h h2) (w w2)', h2=r, w2=r),
    )
Here is the difference:
no need for the special pixel_shuffle operation (and the result is transferable between frameworks)
the output doesn't contain a fake axis (and we could do the same for the input)
in-place ReLU is used now; for high-resolution pictures this becomes critical and saves a lot of memory
and all the benefits of nn.Sequential again
Restyling Gram matrix for style transfer
Original code is already good - its first line shows expected tensor shape
einsum operation should be read like:
for each batch and for each pair of channels, we sum over h and w.
I've also changed normalization, because that's how Gram matrix is defined, otherwise we should call it normalized Gram matrix or alike
def gram_matrix_old(y):
    """Gram matrix of feature maps, normalized by ``ch * h * w``.

    y: tensor of shape (b, ch, h, w). Returns a (b, ch, ch) tensor.
    """
    b, ch, h, w = y.size()
    flat = y.view(b, ch, h * w)
    return torch.bmm(flat, flat.transpose(1, 2)) / (ch * h * w)
def gram_matrix_new(y):
    """Gram matrix of feature maps, averaged over spatial positions.

    For every batch element and every channel pair (c, d), the products
    ``y[b, c] * y[b, d]`` are summed over h and w, then divided by ``h * w``.
    y: (b, ch, h, w). Returns (b, ch, ch).
    """
    _, _, h, w = y.shape
    gram = torch.einsum('bchw,bdhw->bcd', y, y)
    return gram / (h * w)
It would be great to write just 'b c1 h w,b c2 h w->b c1 c2', but torch.einsum supports only single-letter axes
Recurrent model
All we did here was make the information about shapes explicit, so there is nothing left to decipher
class RNNModelOld(nn.Module):
    """Container module with an encoder, a recurrent module, and a decoder.

    Args:
        ntoken: vocabulary size (embedding rows and decoder outputs).
        ninp: embedding size.
        nhid: LSTM hidden size.
        nlayers: number of stacked LSTM layers.
        dropout: dropout applied to embeddings, LSTM outputs and between LSTM layers.
    """

    def __init__(self, ntoken, ninp, nhid, nlayers, dropout=0.5):
        # super() must name this class; there is no RNNModel class here
        super(RNNModelOld, self).__init__()
        self.drop = nn.Dropout(dropout)
        self.encoder = nn.Embedding(ntoken, ninp)
        self.rnn = nn.LSTM(ninp, nhid, nlayers, dropout=dropout)
        self.decoder = nn.Linear(nhid, ntoken)

    def forward(self, input, hidden):
        """Run the language model.

        input: token ids, (seq_len, batch) since the LSTM is not batch_first.
        hidden: LSTM state tuple (h, c).
        Returns (decoded, hidden), where decoded is (seq_len, batch, ntoken).
        """
        emb = self.drop(self.encoder(input))
        output, hidden = self.rnn(emb, hidden)
        output = self.drop(output)
        # merge time and batch for the decoder, then split them back
        decoded = self.decoder(output.view(output.size(0)*output.size(1), output.size(2)))
        return decoded.view(output.size(0), output.size(1), decoded.size(1)), hidden
class RNNModelNew(nn.Module):
    """Container module with an encoder, a recurrent module, and a decoder.

    Args:
        ntoken: vocabulary size (embedding rows and decoder outputs).
        ninp: embedding size.
        nhid: LSTM hidden size.
        nlayers: number of stacked LSTM layers.
        dropout: dropout applied to embeddings, LSTM outputs and between LSTM layers.
    """

    def __init__(self, ntoken, ninp, nhid, nlayers, dropout=0.5):
        # zero-argument super() avoids naming a class (RNNModel) that does not exist
        super().__init__()
        self.drop = nn.Dropout(p=dropout)
        self.encoder = nn.Embedding(ntoken, ninp)
        self.rnn = nn.LSTM(ninp, nhid, nlayers, dropout=dropout)
        self.decoder = nn.Linear(nhid, ntoken)

    def forward(self, input, hidden):
        """Run the language model.

        input: token ids of shape (t, b). hidden: LSTM state tuple.
        Returns (decoded, hidden) with decoded of shape (t, b, ntoken).
        """
        t, b = input.shape
        emb = self.drop(self.encoder(input))
        output, hidden = self.rnn(emb, hidden)
        output = rearrange(self.drop(output), 't b nhid -> (t b) nhid')
        decoded = rearrange(self.decoder(output), '(t b) token -> t b token', t=t, b=b)
        return decoded, hidden
Channel shuffle (from shufflenet)
def channel_shuffle_old(x, groups):
    """Interleave channels across ``groups`` (ShuffleNet channel shuffle).

    x: (batch, channels, height, width). The view fails unless channels are
    divisible by groups. Input channel ``g * per_group + i`` ends up at
    position ``i * groups + g``.
    """
    n, c, h, w = x.data.size()
    per_group = c // groups
    # split channels into (groups, per_group) and swap those two axes ...
    grouped = x.view(n, groups, per_group, h, w).transpose(1, 2)
    # ... transpose leaves a non-contiguous tensor, so copy before flattening
    return grouped.contiguous().view(n, -1, h, w)
def channel_shuffle_new(x, groups):
    """Channel shuffle expressed as one einops pattern.

    The channel axis is read as (group, per_group) and written back as
    (per_group, group), which interleaves the groups.
    """
    pattern = 'b (group per_group) h w -> b (per_group group) h w'
    return rearrange(x, pattern, group=groups)
While progress is obvious, this is not the limit. As you'll see below, we don't even need to write these couple of lines.
Shufflenet
from collections import OrderedDict
def channel_shuffle(x, groups):
    """ShuffleNet channel shuffle used by ShuffleUnitOld.

    Channels are reshaped to (groups, channels // groups), the two axes are
    swapped, and the result is flattened back. Consecutive output channels
    therefore come from different groups.
    """
    b, c, h, w = x.data.size()
    x = x.view(b, groups, c // groups, h, w)
    # swap (groups, per_group); contiguous() is required before view()
    x = x.permute(0, 2, 1, 3, 4).contiguous()
    return x.view(b, -1, h, w)
class ShuffleUnitOld(nn.Module):
    """ShuffleNet unit (original implementation, kept close to its source).

    Args:
        in_channels: number of channels of the input.
        out_channels: number of channels of the unit's output.
        groups: groups for the grouped 1x1 convolutions and the channel shuffle.
        grouped_conv: if False, the first 1x1 convolution is not grouped.
        combine: 'add' (residual sum, depthwise stride 1) or 'concat'
            (average-pooled input concatenated with the branch output,
            depthwise stride 2).

    Raises:
        ValueError: if ``combine`` is neither 'add' nor 'concat'.

    ``conv1x1`` and ``conv3x3`` are helpers from the source repository; they
    are not defined in this snippet.
    """

    def __init__(self, in_channels, out_channels, groups=3,
                 grouped_conv=True, combine='add'):
        super(ShuffleUnitOld, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.grouped_conv = grouped_conv
        self.combine = combine
        self.groups = groups
        self.bottleneck_channels = self.out_channels // 4

        # define the type of ShuffleUnit
        if self.combine == 'add':
            # ShuffleUnit Figure 2b
            self.depthwise_stride = 1
            self._combine_func = self._add
        elif self.combine == 'concat':
            # ShuffleUnit Figure 2c
            self.depthwise_stride = 2
            self._combine_func = self._concat
            # ensure output of concat has the same channels as
            # original output channels.
            self.out_channels -= self.in_channels
        else:
            # a single literal keeps the separators between the two sentences
            message = 'Cannot combine tensors with "{}". Only "add" and "concat" are supported'
            raise ValueError(message.format(self.combine))

        # Use a 1x1 grouped or non-grouped convolution to reduce input channels
        # to bottleneck channels, as in a ResNet bottleneck module.
        # NOTE: Do not use group convolution for the first conv1x1 in Stage 2.
        self.first_1x1_groups = self.groups if grouped_conv else 1

        self.g_conv_1x1_compress = self._make_grouped_conv1x1(
            self.in_channels,
            self.bottleneck_channels,
            self.first_1x1_groups,
            batch_norm=True,
            relu=True
        )

        # 3x3 depthwise convolution followed by batch normalization
        self.depthwise_conv3x3 = conv3x3(
            self.bottleneck_channels, self.bottleneck_channels,
            stride=self.depthwise_stride, groups=self.bottleneck_channels)
        self.bn_after_depthwise = nn.BatchNorm2d(self.bottleneck_channels)

        # Use 1x1 grouped convolution to expand from
        # bottleneck_channels to out_channels
        self.g_conv_1x1_expand = self._make_grouped_conv1x1(
            self.bottleneck_channels,
            self.out_channels,
            self.groups,
            batch_norm=True,
            relu=False
        )

    @staticmethod
    def _add(x, out):
        # residual connection
        return x + out

    @staticmethod
    def _concat(x, out):
        # concatenate along channel axis
        return torch.cat((x, out), 1)

    def _make_grouped_conv1x1(self, in_channels, out_channels, groups,
                              batch_norm=True, relu=False):
        """1x1 (grouped) conv, optionally followed by BatchNorm and ReLU.

        Returns the bare conv when no extra layer is requested, otherwise an
        ``nn.Sequential`` with named children.
        """
        modules = OrderedDict()

        conv = conv1x1(in_channels, out_channels, groups=groups)
        modules['conv1x1'] = conv

        if batch_norm:
            modules['batch_norm'] = nn.BatchNorm2d(out_channels)
        if relu:
            modules['relu'] = nn.ReLU()
        if len(modules) > 1:
            return nn.Sequential(modules)
        else:
            return conv

    def forward(self, x):
        # save for combining later with output
        residual = x

        if self.combine == 'concat':
            # downsample the shortcut to match the stride-2 depthwise branch
            residual = F.avg_pool2d(residual, kernel_size=3,
                                    stride=2, padding=1)

        out = self.g_conv_1x1_compress(x)
        out = channel_shuffle(out, self.groups)
        out = self.depthwise_conv3x3(out)
        out = self.bn_after_depthwise(out)
        out = self.g_conv_1x1_expand(out)

        out = self._combine_func(residual, out)
        return F.relu(out)
class ShuffleUnitNew(nn.Module):
    """ShuffleNet unit built from two branches that ``forward`` combines.

    ``left`` is the shortcut: identity for 'add', 3x3 stride-2 average pooling
    for 'concat'. ``right`` is the bottleneck branch:
    1x1 conv -> BN -> ReLU -> channel shuffle -> 3x3 depthwise conv -> BN
    -> 1x1 grouped conv -> BN.

    Args:
        in_channels: number of channels of the input.
        out_channels: number of channels of the unit's output.
        groups: groups for the channel shuffle and the 1x1 convolutions.
        grouped_conv: if False, the first 1x1 convolution is not grouped.
        combine: 'add' (sum branches) or 'concat' (concatenate along channels).

    Raises:
        ValueError: if ``combine`` is neither 'add' nor 'concat'.

    ``conv1x1`` and ``conv3x3`` are helpers from the source repository; they
    are not defined in this snippet.
    """

    def __init__(self, in_channels, out_channels, groups=3,
                 grouped_conv=True, combine='add'):
        super().__init__()
        # validate up front: any other value used to silently select 'concat'
        if combine not in ('add', 'concat'):
            message = 'Cannot combine tensors with "{}". Only "add" and "concat" are supported'
            raise ValueError(message.format(combine))
        first_1x1_groups = groups if grouped_conv else 1
        bottleneck_channels = out_channels // 4
        self.combine = combine
        if combine == 'add':
            # ShuffleUnit Figure 2b
            self.left = Rearrange('...->...')  # identity
            depthwise_stride = 1
        else:
            # ShuffleUnit Figure 2c
            self.left = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
            depthwise_stride = 2
            # ensure output of concat has the same channels as original output channels.
            out_channels -= in_channels
            assert out_channels > 0

        self.right = nn.Sequential(
            # Use a 1x1 grouped or non-grouped convolution to reduce input channels
            # to bottleneck channels, as in a ResNet bottleneck module.
            conv1x1(in_channels, bottleneck_channels, groups=first_1x1_groups),
            nn.BatchNorm2d(bottleneck_channels),
            nn.ReLU(inplace=True),
            # channel shuffle
            Rearrange('b (c1 c2) h w -> b (c2 c1) h w', c1=groups),
            # 3x3 depthwise convolution followed by batch normalization
            conv3x3(bottleneck_channels, bottleneck_channels,
                    stride=depthwise_stride, groups=bottleneck_channels),
            nn.BatchNorm2d(bottleneck_channels),
            # Use 1x1 grouped convolution to expand from
            # bottleneck_channels to out_channels
            conv1x1(bottleneck_channels, out_channels, groups=groups),
            nn.BatchNorm2d(out_channels),
        )

    def forward(self, x):
        if self.combine == 'add':
            combined = self.left(x) + self.right(x)
        else:
            combined = torch.cat([self.left(x), self.right(x)], dim=1)
        # combined is a freshly allocated tensor, so in-place ReLU is safe
        return F.relu(combined, inplace=True)
Rewriting the code helped to identify:
There is no sense in reshuffling channels when the first convolution is not grouped (indeed, in the paper it is not so). However, this is an equivalent model.
It is also strange that first convolution may be not grouped, while last convolution is always grouped (and that is different from the paper)
Other comments:
You've probably noticed that an identity layer for PyTorch is introduced here: Rearrange('...->...')
The last thing left is to get rid of conv1x1 and conv3x3 in the code — those are no better than the standard nn.Conv2d
Simplifying ResNet
class ResNetOld(nn.Module):
    """ResNet in the classic torchvision layout.

    Args:
        block: residual block class. It must expose an integer ``expansion``
            and accept ``(inplanes, planes, stride, downsample)``.
        layers: number of blocks in each of the four stages.
        num_classes: size of the final linear layer.

    ``AvgPool2d(7)`` followed by ``view(batch, -1)`` assumes the last stage
    yields 7x7 maps (224x224 inputs). Other input sizes generally produce a
    feature count that does not match ``fc``.
    """

    def __init__(self, block, layers, num_classes=1000):
        super(ResNetOld, self).__init__()
        self.inplanes = 64  # channel count fed to the next stage; updated by _make_layer
        # stem
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # four residual stages; every stage after the first uses stride 2
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.fc = nn.Linear(512 * block.expansion, num_classes)
        self._init_weights()

    def _init_weights(self):
        """He-normal (fan-out) init for convolutions; BatchNorm gamma=1, beta=0."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels
                module.weight.data.normal_(0, math.sqrt(2. / fan_out))
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        out_planes = planes * block.expansion
        downsample = None
        if stride != 1 or self.inplanes != out_planes:
            # 1x1 projection so the shortcut matches the block's output shape
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        first = block(self.inplanes, planes, stride, downsample)
        self.inplanes = out_planes
        rest = [block(self.inplanes, planes) for _ in range(1, blocks)]
        return nn.Sequential(first, *rest)

    def forward(self, x):
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        x = self.avgpool(x)
        return self.fc(x.view(x.size(0), -1))
def make_layer(inplanes, planes, block, n_blocks, stride=1):
    """Build one ResNet stage as an ``nn.Sequential`` of ``n_blocks`` blocks.

    Only the first block may change resolution (``stride``) or channel
    count. When it does, a 1x1 conv + BatchNorm projection is passed as its
    ``downsample`` so the residual is adjusted to match. The remaining blocks
    take ``planes * block.expansion`` input channels.
    """
    out_planes = planes * block.expansion
    needs_projection = stride != 1 or inplanes != out_planes
    downsample = (
        nn.Sequential(
            nn.Conv2d(inplanes, out_planes, kernel_size=1, stride=stride, bias=False),
            nn.BatchNorm2d(out_planes),
        )
        if needs_projection
        else None
    )
    blocks = [block(inplanes, planes, stride, downsample)]
    blocks.extend(block(out_planes, planes) for _ in range(n_blocks - 1))
    return nn.Sequential(*blocks)
def ResNetNew(block, layers, num_classes=1000):
    """ResNet as a single ``nn.Sequential``.

    The leading ``Rearrange`` only checks that inputs are (b, 3, 224, 224).
    The trailing ``Reduce`` averages over the spatial axes, replacing the
    AvgPool2d + view pair. Weights get He-normal (fan-out) init for
    convolutions and gamma=1, beta=0 for BatchNorm.
    """
    e = block.expansion
    stem = [
        Rearrange('b c h w -> b c h w', c=3, h=224, w=224),
        nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
        nn.BatchNorm2d(64),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
    ]
    stage_specs = [  # (inplanes, planes, stride)
        (64, 64, 1),
        (64 * e, 128, 2),
        (128 * e, 256, 2),
        (256 * e, 512, 2),
    ]
    stages = [make_layer(inp, planes, block, layers[i], stride=s)
              for i, (inp, planes, s) in enumerate(stage_specs)]
    head = [
        # combined AvgPool and view in one averaging operation
        Reduce('b c h w -> b c', 'mean'),
        nn.Linear(512 * e, num_classes),
    ]
    resnet = nn.Sequential(*stem, *stages, *head)

    for module in resnet.modules():
        if isinstance(module, nn.Conv2d):
            kh, kw = module.kernel_size
            module.weight.data.normal_(0, math.sqrt(2. / (kh * kw * module.out_channels)))
        elif isinstance(module, nn.BatchNorm2d):
            module.weight.data.fill_(1)
            module.bias.data.zero_()
    return resnet
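Things that were changed:
make_layer is now a standalone function instead of a method that mutates self.inplanes
the whole model is a single nn.Sequential, and a Rearrange layer at the start checks that inputs are (b, 3, 224, 224)
AvgPool2d followed by view is replaced with a single Reduce('b c h w -> b c', 'mean')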
Improving RNN language modelling
class RNNOld(nn.Module):
    """LSTM sentence classifier (original version).

    NOTE: ``fc`` always expects ``hidden_dim * 2`` features, and ``forward``
    always concatenates the last two rows of the final hidden state. This
    only matches the bidirectional case.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim, n_layers, bidirectional, dropout):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.rnn = nn.LSTM(embedding_dim, hidden_dim, num_layers=n_layers,
                           bidirectional=bidirectional, dropout=dropout)
        self.fc = nn.Linear(hidden_dim * 2, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # x: [sent len, batch size] token ids
        embedded = self.dropout(self.embedding(x))
        _, (h_n, _) = self.rnn(embedded)
        # h_n: [num layers * num directions, batch size, hid dim]. The last two
        # rows are the top layer's forward and backward states when bidirectional.
        last_two = torch.cat([h_n[-2], h_n[-1]], dim=1)
        return self.fc(self.dropout(last_two).squeeze(0))
class RNNNew(nn.Module):
    """LSTM sentence classifier that works for any layer/direction count.

    The classifier reads the top layer's final hidden state from every
    direction, so ``fc`` has ``hidden_dim * directions`` inputs.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim, n_layers, bidirectional, dropout):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.rnn = nn.LSTM(embedding_dim, hidden_dim, num_layers=n_layers,
                           bidirectional=bidirectional, dropout=dropout)
        self.dropout = nn.Dropout(dropout)
        if bidirectional:
            self.directions = 2
        else:
            self.directions = 1
        self.fc = nn.Linear(self.directions * hidden_dim, output_dim)

    def forward(self, x):
        # x: [sent len, batch size] token ids
        _, (h_n, _) = self.rnn(self.dropout(self.embedding(x)))
        # rows of h_n are (layer dir); gather all directions of each layer
        per_layer = rearrange(h_n, '(layer dir) b c -> layer b (dir c)', dir=self.directions)
        # take the final layer's hidden
        return self.fc(self.dropout(per_layer[-1]))
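the original code misbehaves for non-bidirectional models (fc always expects hidden_dim * 2 features, and the last two layers get concatenated instead of two directions)
and fails when bidirectional=False and there is only one layer (hidden[-2] does not exist)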
modification of the code shows both how hidden is structured and how it is modified
Writing FastText faster
class FastTextOld(nn.Module):
    """FastText-style classifier: average the token embeddings, then one linear layer."""

    def __init__(self, vocab_size, embedding_dim, output_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.fc = nn.Linear(embedding_dim, output_dim)

    def forward(self, x):
        # x: [sent len, batch size] -> embeddings [batch size, sent len, emb dim]
        batch_first = self.embedding(x).transpose(0, 1)
        # avg_pool2d sees the 3-D tensor as (batch, sent len, emb dim) "image"
        # planes; a (sent len, 1) window averages over all tokens -> [batch, 1, emb dim]
        window = (batch_first.shape[1], 1)
        pooled = F.avg_pool2d(batch_first, window).squeeze(1)
        return self.fc(pooled)
def FastTextNew(vocab_size, embedding_dim, output_dim):
    """FastText as a Sequential: embed tokens, average over time, classify.

    The first and last ``Rearrange`` layers are identities. They only check
    and document the expected input ``(t, b)`` and output ``(b, c)`` shapes.
    """
    return nn.Sequential(
        Rearrange('t b -> t b'),                  # input: [sent len, batch]
        nn.Embedding(vocab_size, embedding_dim),
        Reduce('t b c -> b c', 'mean'),           # average over tokens
        nn.Linear(embedding_dim, output_dim),
        Rearrange('b c -> b c'),                  # output: [batch, classes]
    )
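Some comments on new code:
the first and last layers are identity operations: they only check and document the expected input (t b) and output (b c) shapes
for batch-first input, the first layer can be replaced with Rearrange('b t -> t b'),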
CNNs for text classification
class CNNOld(nn.Module):
    """Text CNN with three filter sizes (original version).

    Each filter size gets its own Conv2d whose kernel spans the full
    embedding width, so it slides only along the sentence. Only
    ``filter_sizes[0]``, ``[1]`` and ``[2]`` get a convolution, while ``fc``
    is sized by ``len(filter_sizes)``; exactly three sizes are expected.
    """

    def __init__(self, vocab_size, embedding_dim, n_filters, filter_sizes, output_dim, dropout):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.conv_0 = nn.Conv2d(in_channels=1, out_channels=n_filters, kernel_size=(filter_sizes[0], embedding_dim))
        self.conv_1 = nn.Conv2d(in_channels=1, out_channels=n_filters, kernel_size=(filter_sizes[1], embedding_dim))
        self.conv_2 = nn.Conv2d(in_channels=1, out_channels=n_filters, kernel_size=(filter_sizes[2], embedding_dim))
        self.fc = nn.Linear(len(filter_sizes) * n_filters, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # x: [sent len, batch size] -> [batch size, 1, sent len, emb dim]
        embedded = self.embedding(x.permute(1, 0)).unsqueeze(1)
        pooled = []
        for conv in (self.conv_0, self.conv_1, self.conv_2):
            # [batch, n_filters, sent len - filter size + 1] once the width-1 axis is dropped
            feature_map = F.relu(conv(embedded).squeeze(3))
            # max over the whole sentence -> [batch, n_filters]
            pooled.append(F.max_pool1d(feature_map, feature_map.shape[2]).squeeze(2))
        # [batch, n_filters * 3]
        return self.fc(self.dropout(torch.cat(pooled, dim=1)))
class CNNNew(nn.Module):
    """Text CNN with one Conv1d per entry of ``filter_sizes`` (any number of them)."""

    def __init__(self, vocab_size, embedding_dim, n_filters, filter_sizes, output_dim, dropout):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        convs = [nn.Conv1d(embedding_dim, n_filters, kernel_size=k) for k in filter_sizes]
        self.convs = nn.ModuleList(convs)
        self.fc = nn.Linear(n_filters * len(filter_sizes), output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # identity pattern: only checks that x is 2-D, [sent len, batch size]
        x = rearrange(x, 't b -> t b')
        # embedding channels become the Conv1d input channels
        emb = rearrange(self.embedding(x), 't b c -> b c t')
        # global max over time, one tensor per filter size
        pooled = [reduce(conv(emb), 'b c t -> b c', 'max') for conv in self.convs]
        features = rearrange(pooled, 'filter b c -> b (filter c)')
        return self.fc(self.dropout(F.relu(features)))
Original code misuses Conv2d, while Conv1d is the right choice
Fixed code can work with any number of filter_sizes (and won't fail)
The first line of forward in the new code does nothing except check that the input is 2-D (t b); it was added for clarity
Highway convolutions
Highway convolutions are common in TTS systems. Code below makes splitting a bit more explicit.
The splitting policy may turn out to be important if the input was previously grouped along the channel axis (group convolutions or bidirectional LSTMs/GRUs)
Same applies to GLU and gated units in general
class HighwayConv1dOld(nn.Conv1d):
    """Highway gate on top of a Conv1d: ``g * h + (1 - g) * inputs``.

    The conv output is chunked in two along channels. The first half becomes
    the gate ``g`` (after sigmoid); the second half is the candidate ``h``.
    Both halves must line up with ``inputs``, so out_channels is expected to
    be 2 * in_channels.
    """

    def forward(self, inputs):
        conv_out = super(HighwayConv1dOld, self).forward(inputs)
        gate, candidate = conv_out.chunk(2, dim=1)  # chunk at the feature dim
        gate.sigmoid_()  # in place, on a view of conv_out
        return gate * candidate + (1.0 - gate) * inputs
class HighwayConv1dNew(nn.Conv1d):
    """Highway gate on top of a Conv1d, with an explicit einops channel split."""

    def forward(self, inputs):
        projected = super().forward(inputs)
        # channel axis read as (split c): split index 0 -> gate, 1 -> candidate
        gate, candidate = rearrange(projected, 'b (split c) t -> split b c t', split=2)
        torch.sigmoid_(gate)
        return gate * candidate + (1.0 - gate) * inputs
Tacotron's CBHG module
class CBHG_Old(nn.Module):
    """CBHG module: a recurrent neural network composed of:
        - 1-d convolution banks
        - Highway networks + residual connections
        - Bidirectional gated recurrent units

    Args:
        in_dim: feature size of the input; also used by the residual,
            highway and GRU layers.
        K: number of convolutions in the bank (kernel sizes 1..K).
        projections: output channels of the successive projection
            convolutions. Only the last one has no activation.

    ``BatchNormConv1d`` and ``Highway`` come from the source repository and
    are not defined in this snippet.
    """

    def __init__(self, in_dim, K=16, projections=(128, 128)):
        super().__init__()
        self.in_dim = in_dim
        self.relu = nn.ReLU()
        self.conv1d_banks = nn.ModuleList(
            [BatchNormConv1d(in_dim, in_dim, kernel_size=k, stride=1,
                             padding=k // 2, activation=self.relu)
             for k in range(1, K + 1)])
        self.max_pool1d = nn.MaxPool1d(kernel_size=2, stride=1, padding=1)

        # immutable default; list() keeps the concatenation below valid for any sequence
        projections = list(projections)
        in_sizes = [K * in_dim] + projections[:-1]
        activations = [self.relu] * (len(projections) - 1) + [None]
        self.conv1d_projections = nn.ModuleList(
            [BatchNormConv1d(in_size, out_size, kernel_size=3, stride=1,
                             padding=1, activation=ac)
             for (in_size, out_size, ac) in zip(
                 in_sizes, projections, activations)])

        self.pre_highway = nn.Linear(projections[-1], in_dim, bias=False)
        self.highways = nn.ModuleList(
            [Highway(in_dim, in_dim) for _ in range(4)])

        self.gru = nn.GRU(
            in_dim, in_dim, 1, batch_first=True, bidirectional=True)

    def forward_old(self, inputs):
        """Original forward: inputs (B, T_in, in_dim) -> outputs (B, T_in, 2 * in_dim)."""
        # (B, T_in, in_dim)
        x = inputs

        # Needed to perform conv1d on time-axis
        # (B, in_dim, T_in)
        if x.size(-1) == self.in_dim:
            x = x.transpose(1, 2)

        T = x.size(-1)

        # (B, in_dim*K, T_in)
        # Concat conv1d bank outputs
        x = torch.cat([conv1d(x)[:, :, :T] for conv1d in self.conv1d_banks], dim=1)
        assert x.size(1) == self.in_dim * len(self.conv1d_banks)
        # max pooling with kernel 2, stride 1, padding 1 adds one step; trim back to T
        x = self.max_pool1d(x)[:, :, :T]

        for conv1d in self.conv1d_projections:
            x = conv1d(x)

        # (B, T_in, in_dim)
        # Back to the original shape
        x = x.transpose(1, 2)

        if x.size(-1) != self.in_dim:
            x = self.pre_highway(x)

        # Residual connection
        x += inputs
        for highway in self.highways:
            x = highway(x)

        # (B, T_in, in_dim*2)
        outputs, _ = self.gru(x)

        return outputs

    def forward_new(self, inputs, input_lengths=None):
        """einops version of the forward pass; ``input_lengths`` is accepted but unused.

        inputs: (B, T_in, in_dim). Returns (B, T_in, 2 * in_dim).
        """
        x = rearrange(inputs, 'b t c -> b c t')
        _, _, T = x.shape
        # Concat conv1d bank outputs
        x = rearrange([conv1d(x)[:, :, :T] for conv1d in self.conv1d_banks],
                      'bank b c t -> b (bank c) t', c=self.in_dim)
        x = self.max_pool1d(x)[:, :, :T]

        for conv1d in self.conv1d_projections:
            x = conv1d(x)
        x = rearrange(x, 'b c t -> b t c')
        if x.size(-1) != self.in_dim:
            x = self.pre_highway(x)

        # Residual connection
        x += inputs
        for highway in self.highways:
            x = highway(x)

        # (B, T_in, in_dim*2)
        # Highways were applied in the loop above; an nn.ModuleList is not callable.
        outputs, _ = self.gru(x)
        return outputs
There is still a large room for improvements, but in this example only forward function was changed
Simple attention
Good news: there is no more need to guess order of dimensions. Neither for inputs nor for outputs
class Attention(nn.Module):
    """Scaled dot-product attention over channel-first sequences.

    K: (b, c, t), V: (b, c_v, t), Q: (b, c, l). The weights are normalized
    over key positions t. Returns the attended values concatenated with Q
    along channels: (b, c_v + c, l).
    """

    def __init__(self):
        super().__init__()

    def forward(self, K, V, Q):
        scale = np.sqrt(Q.shape[1])
        # weights[b, t, l]: key position t vs query position l, softmax over t
        weights = F.softmax(K.transpose(1, 2).bmm(Q) / scale, 1)
        readout = V.bmm(weights)
        return torch.cat((readout, Q), dim=1)
def attention(K, V, Q):
    """Scaled dot-product attention written with einsum.

    K: (b, c, t), V: (b, c_v, t), Q: (b, c, l). Scores are scaled by
    ``c ** -0.5`` and normalized over key positions t. Returns the attended
    values concatenated with Q along channels: (b, c_v + c, l).
    """
    _, n_channels, _ = K.shape
    scores = torch.einsum('bct,bcl->btl', K, Q)
    weights = F.softmax(scores * n_channels ** (-0.5), 1)
    readout = torch.einsum('bct,btl->bcl', V, weights)
    return torch.cat((readout, Q), dim=1)
Transformer's attention needs more attention
class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention on batched 3-D tensors.

    forward(q, k, v, mask): q (n, l_q, d), k (n, l_k, d), v (n, l_k, d_v).
    Scores are divided by ``temperature``. Positions where ``mask`` is True
    get zero weight. Returns (output (n, l_q, d_v), attention weights (n, l_q, l_k)).
    """

    def __init__(self, temperature, attn_dropout=0.1):
        super().__init__()
        self.temperature = temperature
        self.dropout = nn.Dropout(attn_dropout)
        self.softmax = nn.Softmax(dim=2)

    def forward(self, q, k, v, mask=None):
        scores = q.bmm(k.transpose(1, 2)) / self.temperature
        if mask is not None:
            # -inf becomes exactly zero weight after softmax
            scores = scores.masked_fill(mask, -np.inf)
        weights = self.dropout(self.softmax(scores))
        return weights.bmm(v), weights
class MultiHeadAttentionOld(nn.Module):
    """Multi-head attention with residual connection and LayerNorm (original version).

    Heads are folded into the batch axis ((n_head * b, len, depth), head-major)
    and passed to ``ScaledDotProductAttention``. NOTE: ``mask`` is repeated
    unconditionally, so calling with ``mask=None`` raises AttributeError.
    """

    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
        super().__init__()
        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v

        self.w_qs = nn.Linear(d_model, n_head * d_k)
        self.w_ks = nn.Linear(d_model, n_head * d_k)
        self.w_vs = nn.Linear(d_model, n_head * d_v)
        qk_std = np.sqrt(2.0 / (d_model + d_k))
        nn.init.normal_(self.w_qs.weight, mean=0, std=qk_std)
        nn.init.normal_(self.w_ks.weight, mean=0, std=qk_std)
        nn.init.normal_(self.w_vs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_v)))

        self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5))
        self.layer_norm = nn.LayerNorm(d_model)

        self.fc = nn.Linear(n_head * d_v, d_model)
        nn.init.xavier_normal_(self.fc.weight)

        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v, mask=None):
        n_head = self.n_head
        _, len_q, _ = q.size()
        _, len_k, _ = k.size()
        batch, len_v, _ = v.size()

        residual = q

        def split_heads(x, length, depth):
            # (b, len, n_head * depth) -> (n_head * b, len, depth), head-major
            x = x.view(batch, length, n_head, depth)
            return x.permute(2, 0, 1, 3).contiguous().view(-1, length, depth)

        q = split_heads(self.w_qs(q), len_q, self.d_k)
        k = split_heads(self.w_ks(k), len_k, self.d_k)
        v = split_heads(self.w_vs(v), len_v, self.d_v)

        # tile the (b, len_q, len_k) mask once per head, matching the head-major layout
        mask = mask.repeat(n_head, 1, 1)
        output, attn = self.attention(q, k, v, mask=mask)

        # (n_head * b, len_q, d_v) -> (b, len_q, n_head * d_v)
        output = output.view(n_head, batch, len_q, self.d_v)
        output = output.permute(1, 2, 0, 3).contiguous().view(batch, len_q, -1)

        output = self.layer_norm(self.dropout(self.fc(output)) + residual)
        return output, attn
class MultiHeadAttentionNew(nn.Module):
    """Multi-head attention with residual connection and LayerNorm, written with einops.

    forward(q, k, v, mask) takes (b, len, d_model) tensors and an optional
    boolean mask (b, l, t) where True marks masked positions. Returns
    (output (b, l, d_model), attention weights (head, b, l, t)).
    """

    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
        super().__init__()
        self.n_head = n_head
        self.w_qs = nn.Linear(d_model, n_head * d_k)
        self.w_ks = nn.Linear(d_model, n_head * d_k)
        self.w_vs = nn.Linear(d_model, n_head * d_v)
        qk_std = np.sqrt(2.0 / (d_model + d_k))
        nn.init.normal_(self.w_qs.weight, mean=0, std=qk_std)
        nn.init.normal_(self.w_ks.weight, mean=0, std=qk_std)
        nn.init.normal_(self.w_vs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_v)))
        self.fc = nn.Linear(n_head * d_v, d_model)
        nn.init.xavier_normal_(self.fc.weight)
        self.dropout = nn.Dropout(p=dropout)
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, q, k, v, mask=None):
        residual = q
        heads = self.n_head
        q = rearrange(self.w_qs(q), 'b l (head k) -> head b l k', head=heads)
        k = rearrange(self.w_ks(k), 'b t (head k) -> head b t k', head=heads)
        v = rearrange(self.w_vs(v), 'b t (head v) -> head b t v', head=heads)
        scale = np.sqrt(q.shape[-1])
        attn = torch.einsum('hblk,hbtk->hblt', q, k) / scale
        if mask is not None:
            # one mask for all heads: broadcast over a new leading head axis
            attn = attn.masked_fill(mask[None], -np.inf)
        attn = attn.softmax(dim=3)
        output = torch.einsum('hblt,hbtv->hblv', attn, v)
        output = rearrange(output, 'head b l v -> b l (head v)')
        output = self.layer_norm(self.dropout(self.fc(output)) + residual)
        return output, attn
Benefits of new implementation
we have one module, not two
now code does not fail for None mask
the number of caveats in the original code that we removed is huge. Try erasing the comments and deciphering what happens there
Self-attention GANs
SAGANs are currently SotA for image generation, and can be simplified using the same tricks. If torch.einsum supported multi-letter axis names, we could improve this solution further.
class Self_Attn_Old(nn.Module):
    """Self-attention layer for feature maps (SAGAN style, original version).

    Queries and keys use ``in_dim // 8`` channels, values keep ``in_dim``.
    The attended map is scaled by the learnable ``gamma`` (initialized to 0)
    and added to the input. ``activation`` is stored but not used in forward.
    """

    def __init__(self, in_dim, activation):
        super().__init__()
        self.chanel_in = in_dim
        self.activation = activation

        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """
        inputs :
            x : input feature maps (B x C x H x W)
        returns :
            out : self attention value + input feature
            attention: B x N x N (N is H * W)
        """
        batch, channels, rows, cols = x.size()
        n = rows * cols
        proj_query = self.query_conv(x).view(batch, -1, n).permute(0, 2, 1)  # B x N x C'
        proj_key = self.key_conv(x).view(batch, -1, n)  # B x C' x N
        attention = self.softmax(torch.bmm(proj_query, proj_key))  # B x N x N
        proj_value = self.value_conv(x).view(batch, -1, n)  # B x C x N

        out = torch.bmm(proj_value, attention.permute(0, 2, 1)).view(batch, channels, rows, cols)
        return self.gamma * out + x, attention
class Self_Attn_New(nn.Module):
    """Self-attention layer for feature maps (einops version).

    Input x: (b, c, h, w). Returns ``(x + gamma * attended, attention)``,
    where attention is (b, h*w, h*w) and normalized over its last axis.
    """

    def __init__(self, in_dim):
        super().__init__()
        reduced = in_dim // 8
        self.query_conv = nn.Conv2d(in_dim, out_channels=reduced, kernel_size=1)
        self.key_conv = nn.Conv2d(in_dim, out_channels=reduced, kernel_size=1)
        self.value_conv = nn.Conv2d(in_dim, out_channels=in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros([1]))

    def forward(self, x):
        query = rearrange(self.query_conv(x), 'b c h w -> b (h w) c')
        key = rearrange(self.key_conv(x), 'b c h w -> b c (h w)')
        value = rearrange(self.value_conv(x), 'b c h w -> b (h w) c')
        attention = F.softmax(torch.bmm(query, key), dim=2)
        attended = torch.bmm(attention, value)
        # restore the spatial layout using h and w read from the input itself
        attended = rearrange(attended, 'b (h w) c -> b c h w', **parse_shape(x, 'b c h w'))
        return x + self.gamma * attended, attention
Improving time sequence prediction
While this example was considered to be simplistic, I had to analyze surrounding code to understand what kind of input was expected. You can try yourself.
One minor change: the new code works with any dtype, not only double, and it supports running on GPU.
class SequencePredictionOld(nn.Module):
    """Two stacked LSTMCells (hidden size 51) predicting one value per step.

    forward(input, future): input is (batch, steps); returns
    (batch, steps + future). The extra ``future`` steps feed each prediction
    back in as the next input. State tensors are created as float64 without
    a device argument, so the model must be ``.double()`` on the default device.
    """

    def __init__(self):
        super(SequencePredictionOld, self).__init__()
        self.lstm1 = nn.LSTMCell(1, 51)
        self.lstm2 = nn.LSTMCell(51, 51)
        self.linear = nn.Linear(51, 1)

    def forward(self, input, future=0):
        batch = input.size(0)
        h_t, c_t, h_t2, c_t2 = [torch.zeros(batch, 51, dtype=torch.double) for _ in range(4)]

        def step(x_t):
            # advance both cells by one time step and return the (batch, 1) prediction
            nonlocal h_t, c_t, h_t2, c_t2
            h_t, c_t = self.lstm1(x_t, (h_t, c_t))
            h_t2, c_t2 = self.lstm2(h_t, (h_t2, c_t2))
            return self.linear(h_t2)

        outputs = [step(x_t) for x_t in input.chunk(input.size(1), dim=1)]
        for _ in range(future):  # if we should predict the future
            outputs.append(step(outputs[-1]))
        return torch.stack(outputs, 1).squeeze(2)
class SequencePredictionNew(nn.Module):
    """Two stacked LSTMCells (hidden size 51) predicting one value per step.

    forward(input, future): input is (b, t); returns (b, t + future). The
    recurrent state follows the dtype and device of the model's parameters.
    """

    def __init__(self):
        super(SequencePredictionNew, self).__init__()
        self.lstm1 = nn.LSTMCell(1, 51)
        self.lstm2 = nn.LSTMCell(51, 51)
        self.linear = nn.Linear(51, 1)

    def forward(self, input, future=0):
        batch, _ = input.shape
        reference = self.linear.weight
        h_t, c_t, h_t2, c_t2 = torch.zeros(4, batch, 51, dtype=reference.dtype,
                                           device=reference.device)

        def step(x_t):
            # advance both cells by one time step and return the (b, 1) prediction
            nonlocal h_t, c_t, h_t2, c_t2
            h_t, c_t = self.lstm1(x_t, (h_t, c_t))
            h_t2, c_t2 = self.lstm2(h_t, (h_t2, c_t2))
            return self.linear(h_t2)

        outputs = []
        # (b, t) -> t slices of shape (b, 1)
        for input_t in rearrange(input, 'b t -> t b ()'):
            output = step(input_t)
            outputs.append(output)
        for _ in range(future):  # if we should predict the future
            output = step(output)
            outputs.append(output)
        return rearrange(outputs, 't b () -> b t')
Transforming spatial transformer network (STN)
class SpacialTransformOld(nn.Module):
    """Spatial transformer network ("before" version).

    A localization network plus a small regressor predict a 2x3 affine matrix
    per image; ``stn`` then resamples the image with that transform.
    """

    def __init__(self):
        # Zero-argument super(): the class passed to super() must be this one.
        super().__init__()
        # Spatial transformer localization-network
        self.localization = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=7),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True),
            nn.Conv2d(8, 10, kernel_size=5),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True)
        )
        # Regressor for the 3 * 2 affine matrix
        self.fc_loc = nn.Sequential(
            nn.Linear(10 * 3 * 3, 32),
            nn.ReLU(True),
            nn.Linear(32, 3 * 2)
        )
        # Initialize the weights/bias with identity transformation
        self.fc_loc[2].weight.data.zero_()
        self.fc_loc[2].bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))

    # Spatial transformer network forward function
    def stn(self, x):
        """Warp ``x`` with the affine transform predicted from ``x`` itself.

        Args:
            x: image batch with 1 channel whose spatial size reduces to 3x3
                in the localization net (e.g. 28x28).

        Returns:
            Tensor with the same shape as ``x``.
        """
        xs = self.localization(x)
        # NOTE: -1 lets the batch axis absorb a size mismatch, so a wrong input
        # resolution is not detected here.
        xs = xs.view(-1, 10 * 3 * 3)
        theta = self.fc_loc(xs)
        theta = theta.view(-1, 2, 3)
        grid = F.affine_grid(theta, x.size())
        x = F.grid_sample(x, grid)
        return x
class SpacialTransformNew(nn.Module):
    """Spatial transformer network built as a single ``nn.Sequential``.

    ``compute_theta`` maps an image batch to per-image 2x3 affine matrices, and
    ``stn`` resamples the input with them. The ``Rearrange`` layers flatten and
    unflatten explicitly and check intermediate shapes, so an unexpected input
    size raises an error inside ``compute_theta``.
    """

    def __init__(self):
        super().__init__()
        # Last regressor layer, initialised so it always outputs the identity
        # transform [[1, 0, 0], [0, 1, 0]].
        linear = nn.Linear(32, 3 * 2)
        with torch.no_grad():
            linear.weight.zero_()
            linear.bias.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))
        self.compute_theta = nn.Sequential(
            # Spatial transformer localization-network
            nn.Conv2d(1, 8, kernel_size=7),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True),
            nn.Conv2d(8, 10, kernel_size=5),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True),
            # Flatten; h=3, w=3 makes einops verify the feature map is 3x3
            # (a 28x28 single-channel input gives exactly that).
            Rearrange('b c h w -> b (c h w)', h=3, w=3),
            # Regressor for the 2x3 affine matrix
            nn.Linear(10 * 3 * 3, 32),
            nn.ReLU(True),
            linear,
            Rearrange('b (row col) -> b row col', row=2, col=3),
        )

    # Spatial transformer network forward function
    def stn(self, x):
        """Warp ``x`` with the affine transform predicted from ``x`` itself.

        Returns a tensor with the same shape as ``x``.
        """
        theta = self.compute_theta(x)
        # align_corners=False is the current default; passing it explicitly
        # avoids PyTorch's warning about that default.
        grid = F.affine_grid(theta, x.size(), align_corners=False)
        return F.grid_sample(x, grid, align_corners=False)
new code gives reasonable errors when the passed image size differs from the expected one
in the old code, if the batch size is divisible by 9 (e.g. 18), `view(-1, 10 * 3 * 3)` succeeds whatever the input size is, so a wrong input fails no sooner than in affine_grid.
Improving GLOW
That's a good old depth-to-space written manually!
Since GLOW is reversible, it frequently relies on rearrange-like operations.
def unsqueeze2d_old(input, factor=2):
    """Depth-to-space written by hand ("before" version).

    Maps a (B, C, H, W) tensor to (B, C // factor**2, H * factor, W * factor):
    each group of factor**2 channels becomes a factor x factor spatial block.
    """
    # NOTE(review): assert-based validation is skipped under ``python -O``.
    assert factor >= 1 and isinstance(factor, int)
    factor2 = factor ** 2
    if factor == 1:
        return input
    size = input.size()
    B = size[0]
    C = size[1]
    H = size[2]
    W = size[3]
    assert C % (factor2) == 0, "{}".format(C)
    # Split channels into (C // factor**2, row offset, column offset).
    x = input.view(B, C // factor2, factor, factor, H, W)
    # Reorder to (B, C // factor**2, H, row offset, W, column offset) so the
    # merge below interleaves offsets with the original rows/columns.
    x = x.permute(0, 1, 4, 2, 5, 3).contiguous()
    x = x.view(B, C // (factor2), H * factor, W * factor)
    return x
def squeeze2d_old(input, factor=2):
    """Space-to-depth written by hand ("before" version).

    Maps a (B, C, H, W) tensor to (B, C * factor**2, H // factor, W // factor):
    each factor x factor spatial block is moved into the channel axis.
    """
    # NOTE(review): assert-based validation is skipped under ``python -O``.
    assert factor >= 1 and isinstance(factor, int)
    if factor == 1:
        return input
    size = input.size()
    B = size[0]
    C = size[1]
    H = size[2]
    W = size[3]
    assert H % factor == 0 and W % factor == 0, "{}".format((H, W))
    # Split each spatial axis into (blocks, offset inside the block).
    x = input.view(B, C, H // factor, factor, W // factor, factor)
    # Reorder to (B, C, row offset, column offset, H // factor, W // factor).
    x = x.permute(0, 1, 3, 5, 2, 4).contiguous()
    x = x.view(B, C * factor * factor, H // factor, W // factor)
    return x
def unsqueeze2d_new(input, factor=2):
    """Depth-to-space: fold ``factor x factor`` channel groups into the spatial axes.

    Maps (b, c * factor**2, h, w) to (b, c, h * factor, w * factor); einops
    raises an error if the channel count is not divisible by ``factor ** 2``.
    """
    block = {'h2': factor, 'w2': factor}
    return rearrange(input, 'b (c h2 w2) h w -> b c (h h2) (w w2)', **block)
def squeeze2d_new(input, factor=2):
    """Space-to-depth: fold each ``factor x factor`` spatial block into channels.

    Maps (b, c, h * factor, w * factor) to (b, c * factor**2, h, w); einops
    raises an error if the spatial sizes are not divisible by ``factor``.
    """
    block = {'h2': factor, 'w2': factor}
    return rearrange(input, 'b c (h h2) (w w2) -> b (c h2 w2) h w', **block)
The term squeeze isn't very helpful: which dimension is squeezed? There is torch.squeeze, but it does something very different.
In fact, we could skip creating these functions completely and call rearrange directly.
Detecting problems in YOLO detection
def YOLO_prediction_old(input, num_classes, num_anchors, anchors, stride_h, stride_w):
    """Decode a raw YOLO head output ("before" version).

    Args:
        input: tensor viewed below as (bs, num_anchors, 5 + num_classes, in_h, in_w),
            i.e. presumably (bs, num_anchors * (5 + num_classes), in_h, in_w).
        num_classes: number of class scores per anchor.
        num_anchors: number of anchors per grid cell.
        anchors: iterable of (a_w, a_h) pairs, one per anchor.
        stride_h, stride_w: multipliers applied to x, y, w, h via ``_scale``.

    Returns:
        Tensor of shape (bs, num_anchors * in_h * in_w, 5 + num_classes) with
        (x, y, w, h, conf, class scores...) along the last axis.
    """
    bs = input.size(0)
    in_h = input.size(2)
    in_w = input.size(3)
    # Anchors divided by the stride; ``* _scale`` at the end multiplies them back.
    scaled_anchors = [(a_w / stride_w, a_h / stride_h) for a_w, a_h in anchors]
    # -> (bs, num_anchors, in_h, in_w, 5 + num_classes): attributes on the last axis.
    prediction = input.view(bs, num_anchors,
        5 + num_classes, in_h, in_w).permute(0, 1, 3, 4, 2).contiguous()
    # Get outputs
    x = torch.sigmoid(prediction[..., 0]) # Center x
    y = torch.sigmoid(prediction[..., 1]) # Center y
    w = prediction[..., 2] # Width
    h = prediction[..., 3] # Height
    conf = torch.sigmoid(prediction[..., 4]) # Conf
    pred_cls = torch.sigmoid(prediction[..., 5:]) # Cls pred.
    # Legacy typed constructors chosen by device.
    # NOTE(review): torch.cuda.* constructors allocate on the current CUDA
    # device, which may differ from the device of ``input``.
    FloatTensor = torch.cuda.FloatTensor if x.is_cuda else torch.FloatTensor
    LongTensor = torch.cuda.LongTensor if x.is_cuda else torch.LongTensor
    # Calculate offsets for each grid
    # NOTE(review): repeat(in_w, 1) builds an in_w x in_w grid (and repeat(in_h, 1)
    # an in_h x in_h one), so view(x.shape) only succeeds when in_h == in_w.
    grid_x = torch.linspace(0, in_w - 1, in_w).repeat(in_w, 1).repeat(
        bs * num_anchors, 1, 1).view(x.shape).type(FloatTensor)
    grid_y = torch.linspace(0, in_h - 1, in_h).repeat(in_h, 1).t().repeat(
        bs * num_anchors, 1, 1).view(y.shape).type(FloatTensor)
    # Calculate anchor w, h
    # (num_anchors, 1) columns of scaled widths / heights, tiled to w.shape / h.shape.
    anchor_w = FloatTensor(scaled_anchors).index_select(1, LongTensor([0]))
    anchor_h = FloatTensor(scaled_anchors).index_select(1, LongTensor([1]))
    anchor_w = anchor_w.repeat(bs, 1).repeat(1, 1, in_h * in_w).view(w.shape)
    anchor_h = anchor_h.repeat(bs, 1).repeat(1, 1, in_h * in_w).view(h.shape)
    # Add offset and scale with anchors
    # Uninitialised float32 buffer; all four slots are overwritten below.
    # ``.data`` detaches, so no gradient flows into the box coordinates
    # (conf and pred_cls keep theirs).
    pred_boxes = FloatTensor(prediction[..., :4].shape)
    pred_boxes[..., 0] = x.data + grid_x
    pred_boxes[..., 1] = y.data + grid_y
    pred_boxes[..., 2] = torch.exp(w.data) * anchor_w
    pred_boxes[..., 3] = torch.exp(h.data) * anchor_h
    # Results
    # [stride_w, stride_h, stride_w, stride_h] scales x, y, w, h respectively.
    _scale = torch.Tensor([stride_w, stride_h] * 2).type(FloatTensor)
    output = torch.cat((pred_boxes.view(bs, -1, 4) * _scale,
        conf.view(bs, -1, 1), pred_cls.view(bs, -1, num_classes)), -1)
    return output
def YOLO_prediction_new(input, num_classes, num_anchors, anchors, stride_h, stride_w):
    """Decode a raw YOLO head output into boxes, confidences and class scores.

    Args:
        input: tensor of shape (b, num_anchors * (5 + num_classes), h, w).
        num_classes: number of class scores predicted per anchor.
        num_anchors: number of anchors per grid cell.
        anchors: ``num_anchors`` (width, height) pairs (list or tensor); box
            sizes are ``exp(raw) * anchor``, so anchors are in output units.
        stride_h, stride_w: factors mapping grid-cell coordinates of the box
            centers to output units.

    Returns:
        Tensor of shape (b, num_anchors * h * w, 5 + num_classes) whose last axis
        holds (center x, center y, width, height, confidence, class scores...).
    """
    raw_predictions = rearrange(input, 'b (anchor prediction) h w -> prediction b anchor h w',
                                anchor=num_anchors, prediction=5 + num_classes)
    _, _, _, in_h, in_w = raw_predictions.shape
    # Build constants directly on the input's device in float32
    # (no CPU allocation followed by a host-to-device copy on every call).
    const_opts = dict(dtype=torch.float32, device=input.device)
    anchor_sizes = rearrange(torch.as_tensor(anchors, **const_opts), 'anchor dim -> dim () anchor () ()')
    grid_h = rearrange(torch.arange(in_h, **const_opts), 'h -> () () h ()')
    grid_w = rearrange(torch.arange(in_w, **const_opts), 'w -> () () () w')
    predicted_bboxes = torch.zeros_like(raw_predictions)
    predicted_bboxes[0] = (raw_predictions[0].sigmoid() + grid_w) * stride_w  # center x
    predicted_bboxes[1] = (raw_predictions[1].sigmoid() + grid_h) * stride_h  # center y
    predicted_bboxes[2:4] = raw_predictions[2:4].exp() * anchor_sizes  # bbox width and height
    predicted_bboxes[4] = raw_predictions[4].sigmoid()  # confidence
    predicted_bboxes[5:] = raw_predictions[5:].sigmoid()  # class predictions
    # merging all predicted bboxes for each image
    return rearrange(predicted_bboxes, 'prediction b anchor h w -> b (anchor h w) prediction')
We changed and fixed a lot:
new code won't fail if input is not on the first GPU
old code has wrong grid_x and grid_y for non-square images
new code doesn't use replication when broadcasting is sufficient
the old code strangely takes .data in some places, but this has no real effect, as other branches (confidence and class predictions) preserve gradients till the end
if gradients are not needed, torch.no_grad should be used, so taking .data is redundant
Simpler output for a bunch of pictures
Next time you need to output drawings of your generative models, you can use this trick:
# Before: torchvision's make_grid (with normalize=True), then a manual
# channels-last transpose so matplotlib can display the result.
device = 'cpu'
plt.imshow(np.transpose(vutils.make_grid(fake_batch.to(device)[:64], padding=2, normalize=True).cpu(),(1,2,0)))
# After: pad every image by 1 pixel on each side, then tile the images into
# 8 rows (b1=8) with channels last.
# NOTE(review): unlike the version above, no normalize=True rescaling is
# applied, and rearrange requires the number of images (at most 64 after
# slicing) to be a multiple of 8.
padded = F.pad(fake_batch[:64], [1, 1, 1, 1])
plt.imshow(rearrange(padded, '(b1 b2) c h w -> (b1 h) (b2 w) c', b1=8).cpu())
Instead of conclusion
Better code is a vague term; to be specific, things that are expected from code are:
reliable: does what is expected and does not fail; explicitly fails on wrong inputs
readability counts
maintainable and modifiable
reusable: understanding and modifying code should be easier than writing from scratch
fast: in my measurements, proposed versions have speed similar to the original code
I've tried to demonstrate how you can improve these criteria for deep learning code. And einops helps you a lot.
Links
pytorch and einops
significant part of the code was taken from official examples and tutorials
(references for other code are given in source of this html, if you're really curious)
einops has a tutorial if you want a gentle introduction