
Writing better code with pytorch and einops


Rewriting building blocks of deep learning

Below are some fragments of code taken from official tutorials and popular repositories (fragments taken for educational purposes, sometimes shortened). For each fragment, an enhanced version is proposed, with comments.

In most examples, einops was used to make things less complicated. But you'll also find some common recommendations and practices to improve the code.

In each pair below, the original code comes first, followed by the improved version.

```python
# start from importing some stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math

from einops import rearrange, reduce, asnumpy, parse_shape
from einops.layers.torch import Rearrange, Reduce
```

Simple ConvNet

```python
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

conv_net_old = Net()
```

```python
conv_net_new = nn.Sequential(
    nn.Conv2d(1, 10, kernel_size=5),
    nn.MaxPool2d(kernel_size=2),
    nn.ReLU(),
    nn.Conv2d(10, 20, kernel_size=5),
    nn.MaxPool2d(kernel_size=2),
    nn.ReLU(),
    nn.Dropout2d(),
    Rearrange('b c h w -> b (c h w)'),
    nn.Linear(320, 50),
    nn.ReLU(),
    nn.Dropout(),
    nn.Linear(50, 10),
    nn.LogSoftmax(dim=1)
)
```
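To make the reshaping issue concrete, here is a small framework-free sketch. It assumes NumPy's `reshape` as a stand-in for `view` (both use row-major order on contiguous data), and the 32x32 input size is a hypothetical example: with 28x28 MNIST images the conv/pool stack produces 20x4x4 = 320 features per sample, while 32x32 images produce 20x5x5 = 500.

```python
import numpy as np

def flatten_fixed(features):
    # mimics x.view(-1, 320): the batch size is inferred from a hard-coded feature count
    return features.reshape(-1, 320)

def flatten_per_sample(features):
    # mimics Rearrange('b c h w -> b (c h w)'): the batch axis is always preserved
    b = features.shape[0]
    return features.reshape(b, -1)

ok = np.zeros((16, 20, 4, 4))          # 28x28 input: 320 features per sample
wrong_size = np.zeros((16, 20, 5, 5))  # 32x32 input: 500 features per sample

print(flatten_fixed(ok).shape)               # (16, 320)
print(flatten_fixed(wrong_size).shape)       # (25, 320): silently wrong batch size
print(flatten_per_sample(wrong_size).shape)  # (16, 500): nn.Linear(320, 50) would now raise
```

With a batch size that is not divisible by 16 (say 15), the hard-coded reshape fails immediately, which is why the problem tends to show up only on unlucky batch sizes.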

Reasons to prefer new code:

- in the original code, if the input size is changed and the batch size is divisible by 16 (that's usually so), we'll get something senseless after reshaping
  - the new code explicitly raises an error in this case
- with the new version, we won't forget to pass the self.training flag to dropout
- the code is straightforward to read and analyze
- nn.Sequential makes printing / saving / passing trivial, and loading a saved model doesn't require your own class definition
- ... and we could also add inplace for ReLU

Super-resolution

```python
class SuperResolutionNetOld(nn.Module):
    def __init__(self, upscale_factor):
        super(SuperResolutionNetOld, self).__init__()

        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))
        self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
        self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
        self.conv4 = nn.Conv2d(32, upscale_factor ** 2, (3, 3), (1, 1), (1, 1))
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.relu(self.conv3(x))
        x = self.pixel_shuffle(self.conv4(x))
        return x
```

```python
def SuperResolutionNetNew(upscale_factor):
    return nn.Sequential(
        nn.Conv2d(1, 64, kernel_size=5, padding=2),
        nn.ReLU(inplace=True),
        nn.Conv2d(64, 64, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(64, 32, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(32, upscale_factor ** 2, kernel_size=3, padding=1),
        Rearrange('b (h2 w2) h w -> b (h h2) (w w2)', h2=upscale_factor, w2=upscale_factor),
    )
```
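A NumPy sketch of what the einops pattern computes, compared against an explicit loop that follows the usual `PixelShuffle` index mapping (channel `di * r + dj` lands at row offset `di`, column offset `dj`). The function names are illustrative, and the sketch covers only the single-output-channel case used in this model.

```python
import numpy as np

def depth_to_space(x, r):
    # NumPy spelling of rearrange(x, 'b (h2 w2) h w -> b (h h2) (w w2)', h2=r, w2=r)
    b, c, h, w = x.shape
    assert c == r * r, "expects exactly r**2 channels (one output channel)"
    x = x.reshape(b, r, r, h, w)       # split channels into (h2, w2)
    x = x.transpose(0, 3, 1, 4, 2)     # -> b h h2 w w2
    return x.reshape(b, h * r, w * r)  # merge (h h2) and (w w2)

def pixel_shuffle_reference(x, r):
    # explicit mapping: out[i*r + di, j*r + dj] = in[di*r + dj, i, j]
    b, c, h, w = x.shape
    out = np.empty((b, h * r, w * r), dtype=x.dtype)
    for di in range(r):
        for dj in range(r):
            out[:, di::r, dj::r] = x[:, di * r + dj]
    return out

x = np.arange(2 * 9 * 4 * 5).reshape(2, 9, 4, 5)
print(np.array_equal(depth_to_space(x, 3), pixel_shuffle_reference(x, 3)))  # True
```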

Here is the difference:

- no need for the special pixel_shuffle operation (and the result is transferable between frameworks)
- the output doesn't contain a fake axis (and we could do the same for the input)
- in-place ReLU is used now; for high-resolution pictures that becomes critical and saves a lot of memory
- and all the benefits of nn.Sequential again

Restyling Gram matrix for style transfer

The original code is already good: its first line shows the expected tensor shape.

The einsum operation should be read as: for each batch and for each pair of channels, sum over h and w. I've also changed the normalization (dividing by h * w rather than ch * h * w), because that's how the Gram matrix is defined; otherwise we should call it a normalized Gram matrix or alike.

```python
def gram_matrix_old(y):
    (b, ch, h, w) = y.size()
    features = y.view(b, ch, w * h)
    features_t = features.transpose(1, 2)
    gram = features.bmm(features_t) / (ch * h * w)
    return gram


def gram_matrix_new(y):
    b, ch, h, w = y.shape
    return torch.einsum('bchw,bdhw->bcd', [y, y]) / (h * w)
```
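A quick NumPy check (an illustration, not part of the original code) that the einsum contraction equals the flatten-and-`bmm` approach, and that the two functions differ only by the factor `ch` in the normalization:

```python
import numpy as np

rng = np.random.default_rng(0)
b, ch, h, w = 2, 3, 4, 5
y = rng.standard_normal((b, ch, h, w))

# old approach: flatten spatial axes, then a batched matrix product with the transpose
features = y.reshape(b, ch, h * w)
gram_bmm = features @ features.transpose(0, 2, 1)   # (b, ch, ch)

# new approach: for each batch and each channel pair (c, d), sum over h and w
gram_einsum = np.einsum('bchw,bdhw->bcd', y, y)

old = gram_bmm / (ch * h * w)   # normalization used by gram_matrix_old
new = gram_einsum / (h * w)     # normalization used by gram_matrix_new
print(np.allclose(gram_bmm, gram_einsum), np.allclose(new, old * ch))  # True True
```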

It would be great to use just 'b c1 h w,b c2 h w->b c1 c2', but einsum supports only one-letter axes.

Recurrent model

All we did here is make the information about shapes explicit, so there's nothing left to decipher.

```python
class RNNModelOld(nn.Module):
    """Container module with an encoder, a recurrent module, and a decoder."""
    def __init__(self, ntoken, ninp, nhid, nlayers, dropout=0.5):
        super(RNNModelOld, self).__init__()
        self.drop = nn.Dropout(dropout)
        self.encoder = nn.Embedding(ntoken, ninp)
        self.rnn = nn.LSTM(ninp, nhid, nlayers, dropout=dropout)
        self.decoder = nn.Linear(nhid, ntoken)

    def forward(self, input, hidden):
        emb = self.drop(self.encoder(input))
        output, hidden = self.rnn(emb, hidden)
        output = self.drop(output)
        decoded = self.decoder(output.view(output.size(0)*output.size(1), output.size(2)))
        return decoded.view(output.size(0), output.size(1), decoded.size(1)), hidden
```

```python
class RNNModelNew(nn.Module):
    """Container module with an encoder, a recurrent module, and a decoder."""
    def __init__(self, ntoken, ninp, nhid, nlayers, dropout=0.5):
        super().__init__()
        self.drop = nn.Dropout(p=dropout)
        self.encoder = nn.Embedding(ntoken, ninp)
        self.rnn = nn.LSTM(ninp, nhid, nlayers, dropout=dropout)
        self.decoder = nn.Linear(nhid, ntoken)

    def forward(self, input, hidden):
        t, b = input.shape
        emb = self.drop(self.encoder(input))
        output, hidden = self.rnn(emb, hidden)
        output = rearrange(self.drop(output), 't b nhid -> (t b) nhid')
        decoded = rearrange(self.decoder(output), '(t b) token -> t b token', t=t, b=b)
        return decoded, hidden
```
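The rewrite relies on einops composing axes in row-major order: in `(t b)` the leftmost axis `t` varies slowest, which is exactly what `output.view(t * b, nhid)` does. A NumPy sketch of that ordering (NumPy `reshape` stands in for `view`; sizes are arbitrary):

```python
import numpy as np

t, b, nhid = 4, 3, 2
output = np.arange(t * b * nhid).reshape(t, b, nhid)

# 't b nhid -> (t b) nhid': row i*b + j holds output[i, j]
flat = output.reshape(t * b, nhid)
assert np.array_equal(flat[1 * b + 2], output[1, 2])

# '(t b) token -> t b token' with t=t, b=b undoes the merge exactly
restored = flat.reshape(t, b, nhid)

# a '(b t)' merge would order rows differently
wrong = output.transpose(1, 0, 2).reshape(b * t, nhid)
```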

Channel shuffle (from shufflenet)

```python
def channel_shuffle_old(x, groups):
    batchsize, num_channels, height, width = x.data.size()

    channels_per_group = num_channels // groups

    # reshape
    x = x.view(batchsize, groups, channels_per_group, height, width)

    # transpose
    # - contiguous() required if transpose() is used before view().
    #   See https://github.com/pytorch/pytorch/issues/764
    x = torch.transpose(x, 1, 2).contiguous()

    # flatten
    x = x.view(batchsize, -1, height, width)

    return x
```

```python
def channel_shuffle_new(x, groups):
    return rearrange(x, 'b (c1 c2) h w -> b (c2 c1) h w', c1=groups)
```
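To see what the pattern does, here is the same operation spelled out with NumPy `reshape`/`transpose` (an illustrative sketch; `channel_shuffle` below is not an einops or PyTorch function). Tagging each channel with its own index shows the resulting order: channels from different groups become interleaved.

```python
import numpy as np

def channel_shuffle(x, groups):
    # split channels into (c1, c2), swap the two factors, merge them back
    b, c, h, w = x.shape
    assert c % groups == 0, "channel count must be divisible by groups"
    x = x.reshape(b, groups, c // groups, h, w)   # b c1 c2 h w
    x = x.transpose(0, 2, 1, 3, 4)                # b c2 c1 h w
    return x.reshape(b, c, h, w)                  # b (c2 c1) h w

# 6 channels in 3 groups of 2: {0, 1}, {2, 3}, {4, 5}
x = np.arange(6).reshape(1, 6, 1, 1)
print(channel_shuffle(x, groups=3).ravel())  # [0 2 4 1 3 5]
```

Shuffling again with `groups = c // groups` (here 2) swaps the factors back, so it restores the original order.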

While the progress is obvious, this is not the limit. As you'll see below, we don't even need to write this function.

Shufflenet

```python
from collections import OrderedDict

# conv1x1 and conv3x3 are small helper functions from the original ShuffleNet
# repository (thin wrappers around nn.Conv2d); they are not shown here


def channel_shuffle(x, groups):
    batchsize, num_channels, height, width = x.data.size()

    channels_per_group = num_channels // groups

    # reshape
    x = x.view(batchsize, groups, channels_per_group, height, width)

    # transpose
    # - contiguous() required if transpose() is used before view().
    #   See https://github.com/pytorch/pytorch/issues/764
    x = torch.transpose(x, 1, 2).contiguous()

    # flatten
    x = x.view(batchsize, -1, height, width)

    return x


class ShuffleUnitOld(nn.Module):
    def __init__(self, in_channels, out_channels, groups=3,
                 grouped_conv=True, combine='add'):

        super(ShuffleUnitOld, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.grouped_conv = grouped_conv
        self.combine = combine
        self.groups = groups
        self.bottleneck_channels = self.out_channels // 4

        # define the type of ShuffleUnit
        if self.combine == 'add':
            # ShuffleUnit Figure 2b
            self.depthwise_stride = 1
            self._combine_func = self._add
        elif self.combine == 'concat':
            # ShuffleUnit Figure 2c
            self.depthwise_stride = 2
            self._combine_func = self._concat

            # ensure output of concat has the same channels as
            # original output channels.
            self.out_channels -= self.in_channels
        else:
            raise ValueError("Cannot combine tensors with \"{}\". "
                             "Only \"add\" and \"concat\" are "
                             "supported".format(self.combine))

        # Use a 1x1 grouped or non-grouped convolution to reduce input channels
        # to bottleneck channels, as in a ResNet bottleneck module.
        # NOTE: Do not use group convolution for the first conv1x1 in Stage 2.
        self.first_1x1_groups = self.groups if grouped_conv else 1

        self.g_conv_1x1_compress = self._make_grouped_conv1x1(
            self.in_channels,
            self.bottleneck_channels,
            self.first_1x1_groups,
            batch_norm=True,
            relu=True
        )

        # 3x3 depthwise convolution followed by batch normalization
        self.depthwise_conv3x3 = conv3x3(
            self.bottleneck_channels, self.bottleneck_channels,
            stride=self.depthwise_stride, groups=self.bottleneck_channels)
        self.bn_after_depthwise = nn.BatchNorm2d(self.bottleneck_channels)

        # Use 1x1 grouped convolution to expand from
        # bottleneck_channels to out_channels
        self.g_conv_1x1_expand = self._make_grouped_conv1x1(
            self.bottleneck_channels,
            self.out_channels,
            self.groups,
            batch_norm=True,
            relu=False
        )

    @staticmethod
    def _add(x, out):
        # residual connection
        return x + out

    @staticmethod
    def _concat(x, out):
        # concatenate along channel axis
        return torch.cat((x, out), 1)

    def _make_grouped_conv1x1(self, in_channels, out_channels, groups,
                              batch_norm=True, relu=False):

        modules = OrderedDict()

        conv = conv1x1(in_channels, out_channels, groups=groups)
        modules['conv1x1'] = conv

        if batch_norm:
            modules['batch_norm'] = nn.BatchNorm2d(out_channels)
        if relu:
            modules['relu'] = nn.ReLU()
        if len(modules) > 1:
            return nn.Sequential(modules)
        else:
            return conv

    def forward(self, x):
        # save for combining later with output
        residual = x

        if self.combine == 'concat':
            residual = F.avg_pool2d(residual, kernel_size=3,
                                    stride=2, padding=1)

        out = self.g_conv_1x1_compress(x)
        out = channel_shuffle(out, self.groups)
        out = self.depthwise_conv3x3(out)
        out = self.bn_after_depthwise(out)
        out = self.g_conv_1x1_expand(out)

        out = self._combine_func(residual, out)
        return F.relu(out)


class ShuffleUnitNew(nn.Module):
    def __init__(self, in_channels, out_channels, groups=3,
                 grouped_conv=True, combine='add'):
        super().__init__()
        first_1x1_groups = groups if grouped_conv else 1
        bottleneck_channels = out_channels // 4
        self.combine = combine
        if combine == 'add':
            # ShuffleUnit Figure 2b
            self.left = Rearrange('...->...')  # identity
            depthwise_stride = 1
        else:
            # ShuffleUnit Figure 2c
            self.left = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
            depthwise_stride = 2
            # ensure output of concat has the same channels as original output channels.
            out_channels -= in_channels
            assert out_channels > 0

        self.right = nn.Sequential(
            # Use a 1x1 grouped or non-grouped convolution to reduce input channels
            # to bottleneck channels, as in a ResNet bottleneck module.
            conv1x1(in_channels, bottleneck_channels, groups=first_1x1_groups),
            nn.BatchNorm2d(bottleneck_channels),
            nn.ReLU(inplace=True),
            # channel shuffle
            Rearrange('b (c1 c2) h w -> b (c2 c1) h w', c1=groups),
            # 3x3 depthwise convolution followed by batch normalization
            conv3x3(bottleneck_channels, bottleneck_channels,
                    stride=depthwise_stride, groups=bottleneck_channels),
            nn.BatchNorm2d(bottleneck_channels),
            # Use 1x1 grouped convolution to expand from
            # bottleneck_channels to out_channels
            conv1x1(bottleneck_channels, out_channels, groups=groups),
            nn.BatchNorm2d(out_channels),
        )

    def forward(self, x):
        if self.combine == 'add':
            combined = self.left(x) + self.right(x)
        else:
            combined = torch.cat([self.left(x), self.right(x)], dim=1)
        return F.relu(combined, inplace=True)
```

Rewriting the code helped to identify:

There is no sense in reshuffling channels when the first convolution is not grouped (and indeed, in the paper it is not so). However, this is an equivalent model. It is also strange that the first convolution may be non-grouped, while the last convolution is always grouped (and that is different from the paper).
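The equivalence claim can be checked with a small sketch. A 1x1 convolution is a matrix product over channels, and BatchNorm and ReLU act per channel, so shuffling the output channels of a non-grouped first convolution is the same as using a convolution with permuted weight rows (and permuted BatchNorm parameters), which the model can learn directly. The sketch below, in NumPy with arbitrary sizes, checks only the convolution part.

```python
import numpy as np

rng = np.random.default_rng(0)
c_in, c_out, n_pixels, groups = 4, 6, 5, 3

x = rng.standard_normal((c_in, n_pixels))  # feature-map pixels, channels first
W = rng.standard_normal((c_out, c_in))     # weights of a non-grouped 1x1 convolution

# permutation performed by 'b (c1 c2) h w -> b (c2 c1) h w' with c1=groups
perm = np.arange(c_out).reshape(groups, c_out // groups).T.ravel()

shuffled_after_conv = (W @ x)[perm]        # convolve, then shuffle channels
conv_with_permuted_weights = W[perm] @ x   # a 1x1 conv whose weight rows are permuted
```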

Other comments:

- you've probably noticed that an identity layer for PyTorch was introduced here
- the last thing left is to get rid of conv1x1 and conv3x3 in the code: those are no better than the standard layers

Simplifying ResNet

```python
# block is a residual block class exposing an `expansion` attribute,
# e.g. BasicBlock or Bottleneck from torchvision.models.resnet

class ResNetOld(nn.Module):

    def __init__(self, block, layers, num_classes=1000):
        self.inplanes = 64
        super(ResNetOld, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x


def make_layer(inplanes, planes, block, n_blocks, stride=1):
    downsample = None
    if stride != 1 or inplanes != planes * block.expansion:
        # output size won't match input, so adjust residual
        downsample = nn.Sequential(
            nn.Conv2d(inplanes, planes * block.expansion,
                      kernel_size=1, stride=stride, bias=False),
            nn.BatchNorm2d(planes * block.expansion),
        )
    return nn.Sequential(
        block(inplanes, planes, stride, downsample),
        *[block(planes * block.expansion, planes) for _ in range(1, n_blocks)]
    )


def ResNetNew(block, layers, num_classes=1000):
    e = block.expansion

    resnet = nn.Sequential(
        Rearrange('b c h w -> b c h w', c=3, h=224, w=224),
        nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
        nn.BatchNorm2d(64),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        make_layer(64, 64, block, layers[0], stride=1),
        make_layer(64 * e, 128, block, layers[1], stride=2),
        make_layer(128 * e, 256, block, layers[2], stride=2),
        make_layer(256 * e, 512, block, layers[3], stride=2),
        # combined AvgPool and view in one averaging operation
        Reduce('b c h w -> b c', 'mean'),
        nn.Linear(512 * e, num_classes),
    )

    # initialization
    for m in resnet.modules():
        if isinstance(m, nn.Conv2d):
            n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            m.weight.data.normal_(0, math.sqrt(2. / n))
        elif isinstance(m, nn.BatchNorm2d):
            m.weight.data.fill_(1)
            m.bias.data.zero_()
    return resnet
```

Things that were changed

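- make_layer is now a standalone function that receives inplanes as an argument instead of modifying self.inplanes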
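The `Reduce('b c h w -> b c', 'mean')` layer replaces `AvgPool2d(7, stride=1)` plus `view`. A NumPy sketch (with the channel count shrunk from 512 to 8 for brevity, and a hand-written pooling helper standing in for `AvgPool2d`) shows they agree on the 7x7 map produced from a 224x224 input, while for a larger map the old head yields a flattened vector of the wrong length and the mean still gives one value per channel:

```python
import numpy as np

def avg_pool2d(x, k):
    # AvgPool2d(kernel_size=k, stride=1) without padding, written with explicit windows
    b, c, h, w = x.shape
    out = np.empty((b, c, h - k + 1, w - k + 1))
    for i in range(h - k + 1):
        for j in range(w - k + 1):
            out[:, :, i, j] = x[:, :, i:i + k, j:j + k].mean(axis=(2, 3))
    return out

def head_old(x):
    # AvgPool2d(7, stride=1) followed by x.view(x.size(0), -1)
    return avg_pool2d(x, 7).reshape(x.shape[0], -1)

def head_new(x):
    # Reduce('b c h w -> b c', 'mean')
    return x.mean(axis=(2, 3))

rng = np.random.default_rng(0)
x7 = rng.standard_normal((2, 8, 7, 7))  # 7x7 map, as produced for a 224x224 input
x9 = rng.standard_normal((2, 8, 9, 9))  # larger map, e.g. from a larger input

print(np.allclose(head_old(x7), head_new(x7)))  # True
print(head_old(x9).shape, head_new(x9).shape)   # (2, 72) (2, 8)
```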

Improving RNN language modelling

```python
class RNNOld(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim,
                 n_layers, bidirectional, dropout):
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.rnn = nn.LSTM(embedding_dim, hidden_dim, num_layers=n_layers,
                           bidirectional=bidirectional, dropout=dropout)
        self.fc = nn.Linear(hidden_dim*2, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        #x = [sent len, batch size]
        embedded = self.dropout(self.embedding(x))
        #embedded = [sent len, batch size, emb dim]
        output, (hidden, cell) = self.rnn(embedded)
        #output = [sent len, batch size, hid dim * num directions]
        #hidden = [num layers * num directions, batch size, hid dim]
        #cell = [num layers * num directions, batch size, hid dim]

        #concat the final forward (hidden[-2,:,:]) and backward (hidden[-1,:,:]) hidden layers
        #and apply dropout
        hidden = self.dropout(torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim=1))
        #hidden = [batch size, hid dim * num directions]

        return self.fc(hidden.squeeze(0))


class RNNNew(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim,
                 n_layers, bidirectional, dropout):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.rnn = nn.LSTM(embedding_dim, hidden_dim, num_layers=n_layers,
                           bidirectional=bidirectional, dropout=dropout)
        self.dropout = nn.Dropout(dropout)
        self.directions = 2 if bidirectional else 1
        self.fc = nn.Linear(hidden_dim * self.directions, output_dim)

    def forward(self, x):
        #x = [sent len, batch size]
        embedded = self.dropout(self.embedding(x))
        #embedded = [sent len, batch size, emb dim]
        output, (hidden, cell) = self.rnn(embedded)
        hidden = rearrange(hidden, '(layer dir) b c -> layer b (dir c)',
                           dir=self.directions)
        # take the final layer's hidden
        return self.fc(self.dropout(hidden[-1]))
```

- the original code misbehaves for non-bidirectional models, and fails outright when bidirectional=False and there is only one layer
- the modification of the code shows both how hidden is structured and how it is modified

Writing FastText faster

```python
class FastTextOld(nn.Module):
    def __init__(self, vocab_size, embedding_dim, output_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.fc = nn.Linear(embedding_dim, output_dim)

    def forward(self, x):
        #x = [sent len, batch size]
        embedded = self.embedding(x)
        #embedded = [sent len, batch size, emb dim]
        embedded = embedded.permute(1, 0, 2)
        #embedded = [batch size, sent len, emb dim]
        pooled = F.avg_pool2d(embedded, (embedded.shape[1], 1)).squeeze(1)
        #pooled = [batch size, embedding_dim]
        return self.fc(pooled)
```

```python
def FastTextNew(vocab_size, embedding_dim, output_dim):
    return nn.Sequential(
        Rearrange('t b -> t b'),
        nn.Embedding(vocab_size, embedding_dim),
        Reduce('t b c -> b c', 'mean'),
        nn.Linear(embedding_dim, output_dim),
        Rearrange('b c -> b c'),
    )
```

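Some comments on the new code:

- the first and last Rearrange layers do nothing, but they make the expected input and output shapes explicit
- should you need to accept inputs as (batch, time), you just change the first layer to Rearrange('b t -> t b'),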

CNNs for text classification

```python
class CNNOld(nn.Module):
    def __init__(self, vocab_size, embedding_dim, n_filters, filter_sizes, output_dim, dropout):
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.conv_0 = nn.Conv2d(in_channels=1, out_channels=n_filters, kernel_size=(filter_sizes[0], embedding_dim))
        self.conv_1 = nn.Conv2d(in_channels=1, out_channels=n_filters, kernel_size=(filter_sizes[1], embedding_dim))
        self.conv_2 = nn.Conv2d(in_channels=1, out_channels=n_filters, kernel_size=(filter_sizes[2], embedding_dim))
        self.fc = nn.Linear(len(filter_sizes)*n_filters, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        #x = [sent len, batch size]
        x = x.permute(1, 0)
        #x = [batch size, sent len]
        embedded = self.embedding(x)
        #embedded = [batch size, sent len, emb dim]
        embedded = embedded.unsqueeze(1)
        #embedded = [batch size, 1, sent len, emb dim]
        conved_0 = F.relu(self.conv_0(embedded).squeeze(3))
        conved_1 = F.relu(self.conv_1(embedded).squeeze(3))
        conved_2 = F.relu(self.conv_2(embedded).squeeze(3))
        #conv_n = [batch size, n_filters, sent len - filter_sizes[n]]
        pooled_0 = F.max_pool1d(conved_0, conved_0.shape[2]).squeeze(2)
        pooled_1 = F.max_pool1d(conved_1, conved_1.shape[2]).squeeze(2)
        pooled_2 = F.max_pool1d(conved_2, conved_2.shape[2]).squeeze(2)
        #pooled_n = [batch size, n_filters]
        cat = self.dropout(torch.cat((pooled_0, pooled_1, pooled_2), dim=1))
        #cat = [batch size, n_filters * len(filter_sizes)]
        return self.fc(cat)


class CNNNew(nn.Module):
    def __init__(self, vocab_size, embedding_dim, n_filters, filter_sizes, output_dim, dropout):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.convs = nn.ModuleList([
            nn.Conv1d(embedding_dim, n_filters, kernel_size=size) for size in filter_sizes
        ])
        self.fc = nn.Linear(len(filter_sizes) * n_filters, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = rearrange(x, 't b -> t b')
        emb = rearrange(self.embedding(x), 't b c -> b c t')
        pooled = [reduce(conv(emb), 'b c t -> b c', 'max') for conv in self.convs]
        concatenated = rearrange(pooled, 'filter b c -> b (filter c)')
        return self.fc(self.dropout(F.relu(concatenated)))
```

- the original code misuses Conv2d, while Conv1d is the right choice
- the fixed code works with any number of filter_sizes (and won't fail)
- the first line of the new code does nothing, but was added to document the expected (time, batch) input layout
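Two details of the new forward pass, sketched in NumPy with arbitrary sizes: `rearrange` applied to a list first stacks its elements along a new leading axis, so `'filter b c -> b (filter c)'` matches `torch.cat(..., dim=1)`; and moving ReLU after max-pooling does not change the result, because ReLU is monotonic.

```python
import numpy as np

rng = np.random.default_rng(0)
b, n_filters, n_sizes = 4, 5, 3
# one (b, n_filters) array of max-pooled activations per filter size
pooled = [rng.standard_normal((b, n_filters)) for _ in range(n_sizes)]

# 'filter b c -> b (filter c)': stack along a new 'filter' axis, move it next to c, merge
stacked = np.stack(pooled)                                        # filter b c
merged = stacked.transpose(1, 0, 2).reshape(b, n_sizes * n_filters)

# old code: torch.cat(..., dim=1)
concatenated = np.concatenate(pooled, axis=1)

def relu(v):
    return np.maximum(v, 0)

# max over time, then ReLU (new) vs ReLU, then max over time (old)
conved = rng.standard_normal((b, n_filters, 11))  # (b, c, t)
relu_after_max = relu(conved.max(axis=2))
relu_before_max = relu(conved).max(axis=2)
```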

Highway convolutions

Highway convolutions are common in TTS systems. The code below makes splitting a bit more explicit. The splitting policy may eventually turn out to be important if the input was previously grouped over the channel axis (group convolutions or bidirectional LSTMs/GRUs). The same applies to GLU and gated units in general.

```python
class HighwayConv1dOld(nn.Conv1d):
    def forward(self, inputs):
        L = super(HighwayConv1dOld, self).forward(inputs)
        H1, H2 = torch.chunk(L, 2, 1)  # chunk at the feature dim
        torch.sigmoid_(H1)
        return H1 * H2 + (1.0 - H1) * inputs
```

```python
class HighwayConv1dNew(nn.Conv1d):
    def forward(self, inputs):
        L = super().forward(inputs)
        H1, H2 = rearrange(L, 'b (split c) t -> split b c t', split=2)
        torch.sigmoid_(H1)
        return H1 * H2 + (1.0 - H1) * inputs
```
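Why the splitting policy matters: `(split c)` takes the first and second halves of the channels, just like `torch.chunk`, whereas `(c split)` would take even and odd channels. A NumPy sketch with each channel holding its own index:

```python
import numpy as np

b, c, t = 1, 3, 1
L = np.arange(2 * c).reshape(b, 2 * c, t)  # channel i holds value i

# 'b (split c) t -> split b c t', split=2: first half vs second half of channels
halves = L.reshape(b, 2, c, t).transpose(1, 0, 2, 3)

# 'b (c split) t -> split b c t', split=2: even vs odd channels (interleaved)
interleaved = L.reshape(b, c, 2, t).transpose(2, 0, 1, 3)

print(halves[0].ravel(), halves[1].ravel())            # [0 1 2] [3 4 5]
print(interleaved[0].ravel(), interleaved[1].ravel())  # [0 2 4] [1 3 5]
```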

Tacotron's CBHG module

```python
class CBHG_Old(nn.Module):
    """CBHG module: a recurrent neural network composed of:
        - 1-d convolution banks
        - Highway networks + residual connections
        - Bidirectional gated recurrent units
    """
    # BatchNormConv1d and Highway are helper modules from the original
    # Tacotron implementation (not shown here)

    def __init__(self, in_dim, K=16, projections=[128, 128]):
        super(CBHG_Old, self).__init__()
        self.in_dim = in_dim
        self.relu = nn.ReLU()
        self.conv1d_banks = nn.ModuleList(
            [BatchNormConv1d(in_dim, in_dim, kernel_size=k, stride=1,
                             padding=k // 2, activation=self.relu)
             for k in range(1, K + 1)])
        self.max_pool1d = nn.MaxPool1d(kernel_size=2, stride=1, padding=1)

        in_sizes = [K * in_dim] + projections[:-1]
        activations = [self.relu] * (len(projections) - 1) + [None]
        self.conv1d_projections = nn.ModuleList(
            [BatchNormConv1d(in_size, out_size, kernel_size=3, stride=1,
                             padding=1, activation=ac)
             for (in_size, out_size, ac) in zip(
                 in_sizes, projections, activations)])

        self.pre_highway = nn.Linear(projections[-1], in_dim, bias=False)
        self.highways = nn.ModuleList(
            [Highway(in_dim, in_dim) for _ in range(4)])

        self.gru = nn.GRU(
            in_dim, in_dim, 1, batch_first=True, bidirectional=True)

    def forward_old(self, inputs):
        # (B, T_in, in_dim)
        x = inputs

        # Needed to perform conv1d on time-axis
        # (B, in_dim, T_in)
        if x.size(-1) == self.in_dim:
            x = x.transpose(1, 2)

        T = x.size(-1)

        # (B, in_dim*K, T_in)
        # Concat conv1d bank outputs
        x = torch.cat([conv1d(x)[:, :, :T] for conv1d in self.conv1d_banks], dim=1)
        assert x.size(1) == self.in_dim * len(self.conv1d_banks)
        x = self.max_pool1d(x)[:, :, :T]

        for conv1d in self.conv1d_projections:
            x = conv1d(x)

        # (B, T_in, in_dim)
        # Back to the original shape
        x = x.transpose(1, 2)

        if x.size(-1) != self.in_dim:
            x = self.pre_highway(x)

        # Residual connection
        x += inputs
        for highway in self.highways:
            x = highway(x)

        # (B, T_in, in_dim*2)
        outputs, _ = self.gru(x)

        return outputs

    def forward_new(self, inputs, input_lengths=None):
        x = rearrange(inputs, 'b t c -> b c t')
        _, _, T = x.shape
        # Concat conv1d bank outputs
        x = rearrange([conv1d(x)[:, :, :T] for conv1d in self.conv1d_banks],
                      'bank b c t -> b (bank c) t', c=self.in_dim)
        x = self.max_pool1d(x)[:, :, :T]

        for conv1d in self.conv1d_projections:
            x = conv1d(x)
        x = rearrange(x, 'b c t -> b t c')
        if x.size(-1) != self.in_dim:
            x = self.pre_highway(x)

        # Residual connection
        x += inputs
        for highway in self.highways:
            x = highway(x)

        # (B, T_in, in_dim*2)
        outputs, _ = self.gru(x)

        return outputs
```

There is still a lot of room for improvement, but in this example only the forward function was changed.

Simple attention

Good news: there is no more need to guess the order of dimensions, neither for inputs nor for outputs.

```python
class Attention(nn.Module):
    def __init__(self):
        super(Attention, self).__init__()

    def forward(self, K, V, Q):
        A = torch.bmm(K.transpose(1,2), Q) / np.sqrt(Q.shape[1])
        A = F.softmax(A, 1)
        R = torch.bmm(V, A)
        return torch.cat((R, Q), dim=1)


def attention(K, V, Q):
    _, n_channels, _ = K.shape
    A = torch.einsum('bct,bcl->btl', [K, Q])
    A = F.softmax(A * n_channels ** (-0.5), 1)
    R = torch.einsum('bct,btl->bcl', [V, A])
    return torch.cat((R, Q), dim=1)
```

Transformer's attention needs more attention

```python
class ScaledDotProductAttention(nn.Module):
    ''' Scaled Dot-Product Attention '''

    def __init__(self, temperature, attn_dropout=0.1):
        super().__init__()
        self.temperature = temperature
        self.dropout = nn.Dropout(attn_dropout)
        self.softmax = nn.Softmax(dim=2)

    def forward(self, q, k, v, mask=None):
        attn = torch.bmm(q, k.transpose(1, 2))
        attn = attn / self.temperature
        if mask is not None:
            attn = attn.masked_fill(mask, -np.inf)
        attn = self.softmax(attn)
        attn = self.dropout(attn)
        output = torch.bmm(attn, v)
        return output, attn


class MultiHeadAttentionOld(nn.Module):
    ''' Multi-Head Attention module '''

    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
        super().__init__()
        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v

        self.w_qs = nn.Linear(d_model, n_head * d_k)
        self.w_ks = nn.Linear(d_model, n_head * d_k)
        self.w_vs = nn.Linear(d_model, n_head * d_v)
        nn.init.normal_(self.w_qs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
        nn.init.normal_(self.w_ks.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
        nn.init.normal_(self.w_vs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_v)))

        self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5))
        self.layer_norm = nn.LayerNorm(d_model)

        self.fc = nn.Linear(n_head * d_v, d_model)
        nn.init.xavier_normal_(self.fc.weight)

        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v, mask=None):
        d_k, d_v, n_head = self.d_k, self.d_v, self.n_head

        sz_b, len_q, _ = q.size()
        sz_b, len_k, _ = k.size()
        sz_b, len_v, _ = v.size()

        residual = q

        q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
        k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
        v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)

        q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k)  # (n*b) x lq x dk
        k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k)  # (n*b) x lk x dk
        v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v)  # (n*b) x lv x dv

        mask = mask.repeat(n_head, 1, 1)  # (n*b) x .. x ..
        output, attn = self.attention(q, k, v, mask=mask)

        output = output.view(n_head, sz_b, len_q, d_v)
        output = output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1)  # b x lq x (n*dv)

        output = self.dropout(self.fc(output))
        output = self.layer_norm(output + residual)

        return output, attn


class MultiHeadAttentionNew(nn.Module):
    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
        super().__init__()
        self.n_head = n_head

        self.w_qs = nn.Linear(d_model, n_head * d_k)
        self.w_ks = nn.Linear(d_model, n_head * d_k)
        self.w_vs = nn.Linear(d_model, n_head * d_v)

        nn.init.normal_(self.w_qs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
        nn.init.normal_(self.w_ks.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
        nn.init.normal_(self.w_vs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_v)))

        self.fc = nn.Linear(n_head * d_v, d_model)
        nn.init.xavier_normal_(self.fc.weight)
        self.dropout = nn.Dropout(p=dropout)
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, q, k, v, mask=None):
        residual = q
        q = rearrange(self.w_qs(q), 'b l (head k) -> head b l k', head=self.n_head)
        k = rearrange(self.w_ks(k), 'b t (head k) -> head b t k', head=self.n_head)
        v = rearrange(self.w_vs(v), 'b t (head v) -> head b t v', head=self.n_head)
        attn = torch.einsum('hblk,hbtk->hblt', [q, k]) / np.sqrt(q.shape[-1])
        if mask is not None:
            attn = attn.masked_fill(mask[None], -np.inf)
        attn = torch.softmax(attn, dim=3)
        output = torch.einsum('hblt,hbtv->hblv', [attn, v])
        output = rearrange(output, 'head b l v -> b l (head v)')
        output = self.dropout(self.fc(output))
        output = self.layer_norm(output + residual)
        return output, attn
```
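The head handling of both versions can be compared in a NumPy sketch (sizes arbitrary; NumPy `reshape`/`transpose` stand in for `view`/`permute`). The old code folds heads into the batch axis, `(n*b) x lq x dk`, and uses a batched matmul; the new code keeps `head` as a separate axis and contracts with one `einsum`. The scores agree once the folded axis is unfolded:

```python
import numpy as np

rng = np.random.default_rng(0)
b, l, t, n_head, d_k = 2, 3, 4, 2, 5
q = rng.standard_normal((b, l, n_head * d_k))  # already projected by w_qs
k = rng.standard_normal((b, t, n_head * d_k))  # already projected by w_ks

# 'b l (head k) -> head b l k': split the last axis, move head to the front
q_heads = q.reshape(b, l, n_head, d_k).transpose(2, 0, 1, 3)
k_heads = k.reshape(b, t, n_head, d_k).transpose(2, 0, 1, 3)

# new code: one einsum over all heads and batches at once
scores_einsum = np.einsum('hblk,hbtk->hblt', q_heads, k_heads) / np.sqrt(d_k)

# old code: fold heads into the batch axis, then a batched matmul
q_flat = q_heads.reshape(n_head * b, l, d_k)
k_flat = k_heads.reshape(n_head * b, t, d_k)
scores_bmm = (q_flat @ k_flat.transpose(0, 2, 1)) / np.sqrt(d_k)
```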

Benefits of the new implementation

- we now have one module, not two
- the code does not fail for a None mask
- the number of caveats in the original code that we removed is huge; try erasing the comments and deciphering what happens there

Self-attention GANs

SAGANs are (at the time of writing) state of the art for image generation, and can be simplified using the same tricks. If torch.einsum supported multi-letter axis names, we could improve this solution further.

```python
class Self_Attn_Old(nn.Module):
    """ Self attention Layer"""

    def __init__(self, in_dim, activation):
        super(Self_Attn_Old, self).__init__()
        self.chanel_in = in_dim
        self.activation = activation

        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """
            inputs :
                x : input feature maps( B X C X W X H)
            returns :
                out : self attention value + input feature
                attention: B X N X N (N is Width*Height)
        """

        m_batchsize, C, width, height = x.size()
        proj_query = self.query_conv(x).view(m_batchsize, -1, width*height).permute(0, 2, 1)  # B X CX(N)
        proj_key = self.key_conv(x).view(m_batchsize, -1, width*height)  # B X C x (*W*H)
        energy = torch.bmm(proj_query, proj_key)  # transpose check
        attention = self.softmax(energy)  # BX (N) X (N)
        proj_value = self.value_conv(x).view(m_batchsize, -1, width*height)  # B X C X N

        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(m_batchsize, C, width, height)

        out = self.gamma*out + x
        return out, attention


class Self_Attn_New(nn.Module):
    """ Self attention Layer"""

    def __init__(self, in_dim):
        super().__init__()
        self.query_conv = nn.Conv2d(in_dim, out_channels=in_dim//8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_dim, out_channels=in_dim//8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_dim, out_channels=in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros([1]))

    def forward(self, x):
        proj_query = rearrange(self.query_conv(x), 'b c h w -> b (h w) c')
        proj_key = rearrange(self.key_conv(x), 'b c h w -> b c (h w)')
        proj_value = rearrange(self.value_conv(x), 'b c h w -> b (h w) c')
        energy = torch.bmm(proj_query, proj_key)
        attention = F.softmax(energy, dim=2)
        out = torch.bmm(attention, proj_value)
        out = x + self.gamma * rearrange(out, 'b (h w) c -> b c h w',
                                         **parse_shape(x, 'b c h w'))
        return out, attention
```

Improving time sequence prediction

While this example was considered to be simplistic, I had to analyze the surrounding code to understand what kind of input was expected. You can try it yourself.

One minor change: the code now works with any dtype, not only double, and the new code supports running on a GPU.

```python
class SequencePredictionOld(nn.Module):
    def __init__(self):
        super(SequencePredictionOld, self).__init__()
        self.lstm1 = nn.LSTMCell(1, 51)
        self.lstm2 = nn.LSTMCell(51, 51)
        self.linear = nn.Linear(51, 1)

    def forward(self, input, future=0):
        outputs = []
        h_t = torch.zeros(input.size(0), 51, dtype=torch.double)
        c_t = torch.zeros(input.size(0), 51, dtype=torch.double)
        h_t2 = torch.zeros(input.size(0), 51, dtype=torch.double)
        c_t2 = torch.zeros(input.size(0), 51, dtype=torch.double)

        for i, input_t in enumerate(input.chunk(input.size(1), dim=1)):
            h_t, c_t = self.lstm1(input_t, (h_t, c_t))
            h_t2, c_t2 = self.lstm2(h_t, (h_t2, c_t2))
            output = self.linear(h_t2)
            outputs += [output]

        for i in range(future):  # if we should predict the future
            h_t, c_t = self.lstm1(output, (h_t, c_t))
            h_t2, c_t2 = self.lstm2(h_t, (h_t2, c_t2))
            output = self.linear(h_t2)
            outputs += [output]
        outputs = torch.stack(outputs, 1).squeeze(2)
        return outputs


class SequencePredictionNew(nn.Module):
    def __init__(self):
        super(SequencePredictionNew, self).__init__()
        self.lstm1 = nn.LSTMCell(1, 51)
        self.lstm2 = nn.LSTMCell(51, 51)
        self.linear = nn.Linear(51, 1)

    def forward(self, input, future=0):
        b, t = input.shape
        h_t, c_t, h_t2, c_t2 = torch.zeros(4, b, 51, dtype=self.linear.weight.dtype,
                                           device=self.linear.weight.device)

        outputs = []
        for input_t in rearrange(input, 'b t -> t b ()'):
            h_t, c_t = self.lstm1(input_t, (h_t, c_t))
            h_t2, c_t2 = self.lstm2(h_t, (h_t2, c_t2))
            output = self.linear(h_t2)
            outputs += [output]

        for i in range(future):  # if we should predict the future
            h_t, c_t = self.lstm1(output, (h_t, c_t))
            h_t2, c_t2 = self.lstm2(h_t, (h_t2, c_t2))
            output = self.linear(h_t2)
            outputs += [output]
        return rearrange(outputs, 't b () -> b t')
```

Transforming spatial transformer network (STN)

```python
class SpacialTransformOld(nn.Module):
    def __init__(self):
        super().__init__()

        # Spatial transformer localization-network
        self.localization = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=7),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True),
            nn.Conv2d(8, 10, kernel_size=5),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True)
        )

        # Regressor for the 3 * 2 affine matrix
        self.fc_loc = nn.Sequential(
            nn.Linear(10 * 3 * 3, 32),
            nn.ReLU(True),
            nn.Linear(32, 3 * 2)
        )

        # Initialize the weights/bias with identity transformation
        self.fc_loc[2].weight.data.zero_()
        self.fc_loc[2].bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))

    # Spatial transformer network forward function
    def stn(self, x):
        xs = self.localization(x)
        xs = xs.view(-1, 10 * 3 * 3)
        theta = self.fc_loc(xs)
        theta = theta.view(-1, 2, 3)

        grid = F.affine_grid(theta, x.size())
        x = F.grid_sample(x, grid)

        return x


class SpacialTransformNew(nn.Module):
    def __init__(self):
        super().__init__()
        # Spatial transformer localization-network
        linear = nn.Linear(32, 3 * 2)
        # Initialize the weights/bias with identity transformation
        linear.weight.data.zero_()
        linear.bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))

        self.compute_theta = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=7),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True),
            nn.Conv2d(8, 10, kernel_size=5),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True),
            Rearrange('b c h w -> b (c h w)', h=3, w=3),
            nn.Linear(10 * 3 * 3, 32),
            nn.ReLU(True),
            linear,
            Rearrange('b (row col) -> b row col', row=2, col=3),
        )

    # Spatial transformer network forward function
    def stn(self, x):
        grid = F.affine_grid(self.compute_theta(x), x.size())
        return F.grid_sample(x, grid)
```

- the new code gives reasonable errors when the passed image size differs from the expected one
- if the batch size is divisible by 18, the old code accepts whatever image size you pass, and fails no sooner than in affine_grid

Improving GLOW

That's a good old depth-to-space written manually!

Since GLOW is reversible, it frequently relies on rearrange-like operations.

def unsqueeze2d_old(input, factor=2):
    assert factor >= 1 and isinstance(factor, int)
    factor2 = factor ** 2
    if factor == 1:
        return input
    size = input.size()
    B = size[0]
    C = size[1]
    H = size[2]
    W = size[3]
    assert C % (factor2) == 0, "{}".format(C)
    x = input.view(B, C // factor2, factor, factor, H, W)
    x = x.permute(0, 1, 4, 2, 5, 3).contiguous()
    x = x.view(B, C // (factor2), H * factor, W * factor)
    return x


def squeeze2d_old(input, factor=2):
    assert factor >= 1 and isinstance(factor, int)
    if factor == 1:
        return input
    size = input.size()
    B = size[0]
    C = size[1]
    H = size[2]
    W = size[3]
    assert H % factor == 0 and W % factor == 0, "{}".format((H, W))
    x = input.view(B, C, H // factor, factor, W // factor, factor)
    x = x.permute(0, 1, 3, 5, 2, 4).contiguous()
    x = x.view(B, C * factor * factor, H // factor, W // factor)
    return x

def unsqueeze2d_new(input, factor=2):
    return rearrange(input, 'b (c h2 w2) h w -> b c (h h2) (w w2)', h2=factor, w2=factor)


def squeeze2d_new(input, factor=2):
    return rearrange(input, 'b c (h h2) (w w2) -> b (c h2 w2) h w', h2=factor, w2=factor)

- the term squeeze isn't very helpful: which dimension is squeezed? There is torch.squeeze, but it does something very different
- in fact, we could skip creating these functions completely

Detecting problems in YOLO detection

def YOLO_prediction_old(input, num_classes, num_anchors, anchors, stride_h, stride_w):
    bs = input.size(0)
    in_h = input.size(2)
    in_w = input.size(3)
    scaled_anchors = [(a_w / stride_w, a_h / stride_h) for a_w, a_h in anchors]

    prediction = input.view(bs, num_anchors,
                            5 + num_classes, in_h, in_w).permute(0, 1, 3, 4, 2).contiguous()
    # Get outputs
    x = torch.sigmoid(prediction[..., 0])  # Center x
    y = torch.sigmoid(prediction[..., 1])  # Center y
    w = prediction[..., 2]  # Width
    h = prediction[..., 3]  # Height
    conf = torch.sigmoid(prediction[..., 4])  # Conf
    pred_cls = torch.sigmoid(prediction[..., 5:])  # Cls pred.

    FloatTensor = torch.cuda.FloatTensor if x.is_cuda else torch.FloatTensor
    LongTensor = torch.cuda.LongTensor if x.is_cuda else torch.LongTensor
    # Calculate offsets for each grid
    grid_x = torch.linspace(0, in_w - 1, in_w).repeat(in_w, 1).repeat(
        bs * num_anchors, 1, 1).view(x.shape).type(FloatTensor)
    grid_y = torch.linspace(0, in_h - 1, in_h).repeat(in_h, 1).t().repeat(
        bs * num_anchors, 1, 1).view(y.shape).type(FloatTensor)
    # Calculate anchor w, h
    anchor_w = FloatTensor(scaled_anchors).index_select(1, LongTensor([0]))
    anchor_h = FloatTensor(scaled_anchors).index_select(1, LongTensor([1]))
    anchor_w = anchor_w.repeat(bs, 1).repeat(1, 1, in_h * in_w).view(w.shape)
    anchor_h = anchor_h.repeat(bs, 1).repeat(1, 1, in_h * in_w).view(h.shape)
    # Add offset and scale with anchors
    pred_boxes = FloatTensor(prediction[..., :4].shape)
    pred_boxes[..., 0] = x.data + grid_x
    pred_boxes[..., 1] = y.data + grid_y
    pred_boxes[..., 2] = torch.exp(w.data) * anchor_w
    pred_boxes[..., 3] = torch.exp(h.data) * anchor_h
    # Results
    _scale = torch.Tensor([stride_w, stride_h] * 2).type(FloatTensor)
    output = torch.cat((pred_boxes.view(bs, -1, 4) * _scale,
                        conf.view(bs, -1, 1), pred_cls.view(bs, -1, num_classes)), -1)
    return output


def YOLO_prediction_new(input, num_classes, num_anchors, anchors, stride_h, stride_w):
    raw_predictions = rearrange(input, 'b (anchor prediction) h w -> prediction b anchor h w',
                                anchor=num_anchors, prediction=5 + num_classes)
    anchors = torch.FloatTensor(anchors).to(input.device)
    anchor_sizes = rearrange(anchors, 'anchor dim -> dim () anchor () ()')

    _, _, _, in_h, in_w = raw_predictions.shape
    grid_h = rearrange(torch.arange(in_h).float(), 'h -> () () h ()').to(input.device)
    grid_w = rearrange(torch.arange(in_w).float(), 'w -> () () () w').to(input.device)

    predicted_bboxes = torch.zeros_like(raw_predictions)
    predicted_bboxes[0] = (raw_predictions[0].sigmoid() + grid_w) * stride_w  # center x
    predicted_bboxes[1] = (raw_predictions[1].sigmoid() + grid_h) * stride_h  # center y
    predicted_bboxes[2:4] = (raw_predictions[2:4].exp()) * anchor_sizes  # bbox width and height
    predicted_bboxes[4] = raw_predictions[4].sigmoid()  # confidence
    predicted_bboxes[5:] = raw_predictions[5:].sigmoid()  # class predictions
    # merging all predicted bboxes for each image
    return rearrange(predicted_bboxes, 'prediction b anchor h w -> b (anchor h w) prediction')

We changed and fixed a lot:

- new code won't fail if the input is not on the first GPU
- old code has wrong grid_x and grid_y for non-square images
- new code doesn't use replication when broadcasting is sufficient
- old code strangely takes .data in some places, but this has no real effect, as some branches preserve the gradient until the end
- if gradients are not needed, torch.no_grad should be used, so taking .data is redundant

Simpler output for a bunch of pictures

Next time you need to output drawings of your generative models, you can use this trick:

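import matplotlib.pyplot as plt
import torchvision.utils as vutils
# fake_batch: generated images of shape (n, c, h, w), e.g. the generator output in the
# DCGAN tutorial; assumed to be computed under torch.no_grad() so it can be plotted directly

# old
device = 'cpu'
plt.imshow(np.transpose(vutils.make_grid(fake_batch.to(device)[:64], padding=2, normalize=True).cpu(), (1, 2, 0)))

# new: pad every image by one pixel, then tile 64 images into an 8 x 8 mosaic, channels last
# (unlike make_grid(normalize=True), values are not rescaled to [0, 1])
padded = F.pad(fake_batch[:64], [1, 1, 1, 1])
plt.imshow(rearrange(padded, '(b1 b2) c h w -> (b1 h) (b2 w) c', b1=8).cpu())

Instead of conclusion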

Better code is a vague term; to be specific, things that are expected from code are:

- reliable: does what is expected and does not fail; explicitly fails for wrong inputs
- readability counts
- maintainable and modifiable
- reusable: understanding and modifying code should be easier than writing it from scratch
- fast: in my measurements, the proposed versions have speed similar to the original code

I've tried to demonstrate how you can improve deep learning code along these criteria, and einops helps a lot with that.

Links

- pytorch and einops
- a significant part of the code was taken from official examples and tutorials (references for the other code are given in the source of the original HTML page, if you're really curious)
- einops has a tutorial if you want a gentle introduction

