This is a PyTorch implementation of Batch-Channel Normalization from the paper Micro-Batch Training with Batch-Channel Normalization and Weight Standardization. We also have an annotated implementation of Weight Standardization.

Batch-Channel Normalization performs batch normalization followed by a channel normalization (similar to Group Normalization). When the batch size is small, a running mean and variance are used for batch normalization.

Here is the training code for a VGG network that uses weight standardization to classify CIFAR-10 data.

```
import torch
from torch import nn

from labml_helpers.module import Module
from labml_nn.normalization.batch_norm import BatchNorm
```

This first performs a batch normalization, either a normal batch norm or a batch norm with estimated mean and variance (an exponential moving average over multiple batches). Then a channel normalization is performed.

`class BatchChannelNorm(Module):`

- `channels` is the number of features in the input
- `groups` is the number of groups the features are divided into
- `eps` is $\epsilon$, used in $\sqrt{Var[x^{(k)}] + \epsilon}$ for numerical stability
- `momentum` is the momentum in taking the exponential moving average
- `estimate` is whether to use running mean and variance for batch norm

```
def __init__(self, channels: int, groups: int,
             eps: float = 1e-5, momentum: float = 0.1, estimate: bool = True):
```

`super().__init__()`

Use estimated batch norm or normal batch norm.

```
if estimate:
    self.batch_norm = EstimatedBatchNorm(channels,
                                         eps=eps, momentum=momentum)
else:
    self.batch_norm = BatchNorm(channels,
                                eps=eps, momentum=momentum)
```

Channel normalization

`self.channel_norm = ChannelNorm(channels, groups, eps)`

```
def forward(self, x):
    x = self.batch_norm(x)
    return self.channel_norm(x)
```
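To see the overall data flow, here is a minimal standalone sketch of the same two-step composition using PyTorch's built-in layers (`BatchNorm2d` standing in for the batch-norm step and `GroupNorm` for the channel normalization); the sizes are hypothetical, chosen only for illustration:

```python
import torch
from torch import nn

torch.manual_seed(0)
batch_size, channels, groups = 8, 32, 4  # hypothetical sizes

# Batch normalization followed by a group-wise (channel) normalization,
# mirroring the two steps in BatchChannelNorm's forward pass
batch_norm = nn.BatchNorm2d(channels)
channel_norm = nn.GroupNorm(groups, channels)

x = torch.randn(batch_size, channels, 16, 16)
y = channel_norm(batch_norm(x))
```

The shape is preserved throughout; only the per-channel and per-group statistics of the activations change.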

The input $X \in \mathbb{R}^{B \times C \times H \times W}$ is a batch of image representations, where $B$ is the batch size, $C$ is the number of channels, $H$ is the height and $W$ is the width. The learned per-channel parameters are $\gamma \in \mathbb{R}^{C}$ and $\beta \in \mathbb{R}^{C}$.

$\dot{X}_{\cdot, C, \cdot, \cdot} = \gamma_C \frac{X_{\cdot, C, \cdot, \cdot} - \hat{\mu}_C}{\hat{\sigma}_C} + \beta_C$

where,

$\hat{\mu}_C \longleftarrow (1 - r)\hat{\mu}_C + r \frac{1}{BHW} \sum_{b,h,w} X_{b,c,h,w}$

$\hat{\sigma}^2_C \longleftarrow (1 - r)\hat{\sigma}^2_C + r \frac{1}{BHW} \sum_{b,h,w} \big(X_{b,c,h,w} - \hat{\mu}_C\big)^2$

are the running mean and variance, and $r$ is the momentum for the exponential moving average.
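A single step of this update rule can be sketched in isolation; the channel count and momentum below are arbitrary illustrative values:

```python
import torch

torch.manual_seed(0)
r = 0.1  # momentum

# Running estimates start at zero mean and unit variance
exp_mean = torch.zeros(3)
exp_var = torch.ones(3)

x = torch.randn(8, 3, 4, 4)
# Batch statistics per channel, over the batch and spatial dimensions
mean = x.mean(dim=[0, 2, 3])
var = ((x - mean.view(1, -1, 1, 1)) ** 2).mean(dim=[0, 2, 3])

# One exponential-moving-average step
exp_mean = (1 - r) * exp_mean + r * mean
exp_var = (1 - r) * exp_var + r * var
```

With a small momentum $r$, the running estimates move only slightly toward each new batch's statistics, which is what makes them usable when individual batches are too small to estimate the statistics reliably.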

`class EstimatedBatchNorm(Module):`

- `channels` is the number of features in the input
- `eps` is $\epsilon$, used in $\sqrt{Var[x^{(k)}] + \epsilon}$ for numerical stability
- `momentum` is the momentum in taking the exponential moving average
- `affine` is whether to scale and shift the normalized value

```
def __init__(self, channels: int,
             eps: float = 1e-5, momentum: float = 0.1, affine: bool = True):
```

```
super().__init__()

self.eps = eps
self.momentum = momentum
self.affine = affine
self.channels = channels
```

Channel wise transformation parameters

```
if self.affine:
    self.scale = nn.Parameter(torch.ones(channels))
    self.shift = nn.Parameter(torch.zeros(channels))
```

Tensors for $\hat{\mu}_C$ and $\hat{\sigma}^2_C$

```
self.register_buffer('exp_mean', torch.zeros(channels))
self.register_buffer('exp_var', torch.ones(channels))
```

`x` is a tensor of shape `[batch_size, channels, *]`. `*` denotes any number of (possibly 0) dimensions. For example, in an image (2D) convolution this will be `[batch_size, channels, height, width]`.

`def forward(self, x: torch.Tensor):`

Keep old shape

`x_shape = x.shape`

Get the batch size

`batch_size = x_shape[0]`

Sanity check to make sure the number of features is correct

`assert self.channels == x.shape[1]`

Reshape into `[batch_size, channels, n]`

`x = x.view(batch_size, self.channels, -1)`

Update $\hat{\mu}_C$ and $\hat{\sigma}^2_C$ in training mode only

`if self.training:`

No backpropagation through $\hat{\mu}_C$ and $\hat{\sigma}^2_C$

`with torch.no_grad():`

Calculate the mean across the first and last dimensions; $\frac{1}{BHW} \sum_{b,h,w} X_{b,c,h,w}$

`mean = x.mean(dim=[0, 2])`

Calculate the squared mean across the first and last dimensions; $\frac{1}{BHW} \sum_{b,h,w} X^2_{b,c,h,w}$

`mean_x2 = (x ** 2).mean(dim=[0, 2])`

Variance for each feature: $\frac{1}{BHW} \sum_{b,h,w} \big(X_{b,c,h,w} - \hat{\mu}_C\big)^2$

`var = mean_x2 - mean ** 2`
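The shortcut $Var[X] = \mathbb{E}[X^2] - \mathbb{E}[X]^2$ used here gives the biased (population) variance; a quick standalone check against the direct definition:

```python
import torch

torch.manual_seed(0)
x = torch.randn(8, 3, 100).double()  # double precision avoids cancellation error

mean = x.mean(dim=[0, 2])
mean_x2 = (x ** 2).mean(dim=[0, 2])
var = mean_x2 - mean ** 2

# Same value as the population variance computed from its definition
direct = ((x - mean.view(1, -1, 1)) ** 2).mean(dim=[0, 2])
assert torch.allclose(var, direct)
```

In low precision (e.g. `float16`) this shortcut can suffer catastrophic cancellation when the mean is large relative to the variance, which is worth keeping in mind when porting the code.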

Update exponential moving averages

$\hat{\mu}_C \longleftarrow (1 - r)\hat{\mu}_C + r \frac{1}{BHW} \sum_{b,h,w} X_{b,c,h,w}$

$\hat{\sigma}^2_C \longleftarrow (1 - r)\hat{\sigma}^2_C + r \frac{1}{BHW} \sum_{b,h,w} \big(X_{b,c,h,w} - \hat{\mu}_C\big)^2$

```
self.exp_mean = (1 - self.momentum) * self.exp_mean + self.momentum * mean
self.exp_var = (1 - self.momentum) * self.exp_var + self.momentum * var
```

Normalize $\frac{X_{\cdot, C, \cdot, \cdot} - \hat{\mu}_C}{\hat{\sigma}_C}$

`x_norm = (x - self.exp_mean.view(1, -1, 1)) / torch.sqrt(self.exp_var + self.eps).view(1, -1, 1)`

Scale and shift $\gamma_C \frac{X_{\cdot, C, \cdot, \cdot} - \hat{\mu}_C}{\hat{\sigma}_C} + \beta_C$

```
if self.affine:
    x_norm = self.scale.view(1, -1, 1) * x_norm + self.shift.view(1, -1, 1)
```

Reshape to original and return

`return x_norm.view(x_shape)`

This is similar to Group Normalization, but the affine transform is done group-wise.

`class ChannelNorm(Module):`

- `channels` is the number of features in the input
- `groups` is the number of groups the features are divided into
- `eps` is $\epsilon$, used in $\sqrt{Var[x^{(k)}] + \epsilon}$ for numerical stability
- `affine` is whether to scale and shift the normalized value

```
def __init__(self, channels, groups,
             eps: float = 1e-5, affine: bool = True):
```

```
super().__init__()
self.channels = channels
self.groups = groups
self.eps = eps
self.affine = affine
```

Parameters for affine transformation.

*Note that these transforms are per group, unlike group norm, where they are per channel.*

```
if self.affine:
    self.scale = nn.Parameter(torch.ones(groups))
    self.shift = nn.Parameter(torch.zeros(groups))
```
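The difference in parameter shapes is easy to see side by side; a small sketch with arbitrary sizes, contrasting these group-wise parameters with PyTorch's built-in per-channel ones:

```python
import torch
from torch import nn

groups, channels = 4, 32  # illustrative sizes

# Group-wise affine parameters, as in this ChannelNorm:
# one scale and one shift per group
scale = nn.Parameter(torch.ones(groups))
shift = nn.Parameter(torch.zeros(groups))

# PyTorch's GroupNorm keeps per-channel affine parameters instead
gn = nn.GroupNorm(groups, channels)
```

Here `scale` has `groups` elements while `gn.weight` has `channels` elements, so this layer has fewer affine parameters than a standard group norm over the same input.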

`x` is a tensor of shape `[batch_size, channels, *]`. `*` denotes any number of (possibly 0) dimensions. For example, in an image (2D) convolution this will be `[batch_size, channels, height, width]`.

`def forward(self, x: torch.Tensor):`

Keep the original shape

`x_shape = x.shape`

Get the batch size

`batch_size = x_shape[0]`

Sanity check to make sure the number of features is the same

`assert self.channels == x.shape[1]`

Reshape into `[batch_size, groups, n]`

`x = x.view(batch_size, self.groups, -1)`

Calculate the mean across the last dimension; i.e. the mean for each sample and channel group $\mathbb{E}[x_{(i_N, i_G)}]$

`mean = x.mean(dim=[-1], keepdim=True)`

Calculate the squared mean across the last dimension; i.e. the mean of $x^2$ for each sample and channel group $\mathbb{E}[x^2_{(i_N, i_G)}]$

`mean_x2 = (x ** 2).mean(dim=[-1], keepdim=True)`

Variance for each sample and feature group $Var[x_{(i_N, i_G)}] = \mathbb{E}[x^2_{(i_N, i_G)}] - \mathbb{E}[x_{(i_N, i_G)}]^2$

`var = mean_x2 - mean ** 2`

Normalize $\hat{x}_{(i_N, i_G)} = \frac{x_{(i_N, i_G)} - \mathbb{E}[x_{(i_N, i_G)}]}{\sqrt{Var[x_{(i_N, i_G)}] + \epsilon}}$

`x_norm = (x - mean) / torch.sqrt(var + self.eps)`

Scale and shift group-wise $y_{i_G} = \gamma_{i_G} \hat{x}_{i_G} + \beta_{i_G}$

```
if self.affine:
    x_norm = self.scale.view(1, -1, 1) * x_norm + self.shift.view(1, -1, 1)
```

Reshape to original and return

`return x_norm.view(x_shape)`
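Without the affine transform, this channel normalization computes exactly the same statistics as group normalization. A self-contained sketch (with arbitrary sizes) checking the normalization step against `torch.nn.functional.group_norm`:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, channels, groups, eps = 4, 8, 2, 1e-5
x = torch.randn(batch_size, channels, 5, 5)

# Normalize per (sample, group), as in ChannelNorm's forward pass
g = x.view(batch_size, groups, -1)
mean = g.mean(dim=-1, keepdim=True)
var = (g ** 2).mean(dim=-1, keepdim=True) - mean ** 2
x_norm = ((g - mean) / torch.sqrt(var + eps)).view(x.shape)

# Matches PyTorch's group norm with no affine parameters
expected = F.group_norm(x, groups, eps=eps)
```

The two layers differ only in the affine step: `nn.GroupNorm` scales and shifts per channel, while this `ChannelNorm` does so per group.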