Posit AI Blog: Implementing rotation equivariance: Group-equivariant CNN from scratch



Convolutional neural networks (CNNs) are great: they're able to detect features in an image no matter where. Well, not exactly. They're not indifferent to just any kind of movement. Shifting up or down, or left or right, is fine; rotating around an axis is not. That's because of how convolution works: traverse by row, then traverse by column (or the other way round). If we want "more" (e.g., successful detection of an upside-down object), we need to extend convolution to an operation that is rotation-equivariant. An operation that is equivariant to some type of action will not only register the moved feature per se, but also keep track of which concrete action made it appear where it is.

This is the second post in a series that introduces group-equivariant CNNs (GCNNs). The first was a high-level introduction to why we'd want them, and how they work. There, we introduced the key player, the symmetry group, which specifies what kinds of transformations are to be treated equivariantly. If you haven't yet, please take a look at that post first, since here I'll make use of terminology and concepts it introduced.

Today, we code a simple GCNN from scratch. Code and presentation tightly follow a notebook provided as part of the University of Amsterdam's 2022 Deep Learning Course. They can't be thanked enough for making such excellent learning materials available.

In what follows, my intent is to explain the general thinking, and how the resulting architecture is built up from smaller modules, each of which is assigned a clear purpose. For that reason, I won't reproduce all the code here; instead, I'll make use of the package gcnn. Its methods are heavily annotated; so to see some details, don't hesitate to look at the code.

As of today, gcnn implements one symmetry group: \(C_4\), the one that serves as a running example throughout post one. It is straightforwardly extensible, though, making use of class hierarchies throughout.

Step 1: The symmetry group \(C_4\)

In coding a GCNN, the first thing we need to provide is an implementation of the symmetry group we'd like to use. Here, it is \(C_4\), the four-element group that rotates by 90 degrees.

We can ask gcnn to create one for us, and inspect its elements.

# remotes::install_github("skeydan/gcnn")
library(gcnn)
library(torch)

C_4 <- CyclicGroup(order = 4)
elems <- C_4$elements()
elems
torch_tensor
 0.0000
 1.5708
 3.1416
 4.7124
[ CPUFloatType{4} ]

Elements are represented by their respective rotation angles: \(0\), \(\frac{\pi}{2}\), \(\pi\), and \(\frac{3 \pi}{2}\).

Groups are aware of the identity, and know how to construct an element's inverse:

C_4$identity

g1 <- elems[2]
C_4$inverse(g1)
torch_tensor
 0
[ CPUFloatType{1} ]

torch_tensor
4.71239
[ CPUFloatType{} ]

Here, what we care about most is the group elements' action. Implementation-wise, we need to distinguish between their acting on each other, and their action on the vector space \(\mathbb{R}^2\), where our input images live. The former part is the easy one: It may simply be implemented by adding angles (modulo \(2\pi\)). In fact, this is what gcnn does when we ask it to let g1 act on g2:

g2 <- elems[3]

# in C_4$left_action_on_H(), H stands for the symmetry group
C_4$left_action_on_H(torch_tensor(g1)$unsqueeze(1), torch_tensor(g2)$unsqueeze(1))
torch_tensor
 4.7124
[ CPUFloatType{1,1} ]

What's with the unsqueeze()s? Since \(C_4\)'s ultimate raison d'être is to be part of a neural network, left_action_on_H() works with batches of elements, not scalar tensors.
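To make "acting by adding angles" concrete without relying on gcnn's internals, here is a minimal base-R sketch of \(C_4\) with elements stored as angles. The helper names (c4_elements, c4_product, c4_inverse) are hypothetical and not part of gcnn; the sketch only assumes that composition is addition modulo \(2\pi\) and that the inverse is the negated angle. Because R arithmetic is vectorized, the same functions also handle batches of elements.

```r
# Hypothetical sketch of C_4, elements stored as rotation angles in radians.
# Not gcnn's implementation.
c4_elements <- (0:3) * pi / 2

# Map any angle back into [0, 2*pi).
c4_normalize <- function(theta) theta %% (2 * pi)

# Left action of g on h (the group product): add angles, then wrap around.
c4_product <- function(g, h) c4_normalize(g + h)

# Inverse: the angle that takes us back to the identity, 0.
c4_inverse <- function(g) c4_normalize(-g)

g1 <- c4_elements[2]  # pi/2
g2 <- c4_elements[3]  # pi
c4_product(g1, g2)    # 3*pi/2, approx. 4.7124
c4_inverse(g1)        # 3*pi/2, approx. 4.7124

# Batches work element-wise: (pi/2, pi) acting on (pi, 3*pi/2)
c4_product(c4_elements[2:3], c4_elements[3:4])  # approx. 4.7124, 1.5708
```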

Things are a bit less straightforward where the group action on \(\mathbb{R}^2\) is concerned. Here, we need the concept of a group representation. This is an involved topic, which we won't go into here. In our current context, it works about like this: We have an input signal, a tensor we'd like to operate on in some way. (That "some way" will be convolution, as we'll see soon.) To render that operation group-equivariant, we first have the representation apply the inverse group action to the input. That accomplished, we go on with the operation as though nothing had happened.

To give a concrete example, let's say the operation is a measurement. Imagine a runner, standing at the foot of some mountain trail, ready to run up the climb. We'd like to record their height. One option we have is to take the measurement, then let them run up. Our measurement will be as valid up the mountain as it was down here. Alternatively, we might be polite and not make them wait. Once they're up there, we ask them to come down, and when they're back, we measure their height. The result is the same: Body height is equivariant (more than that: invariant, even) to the action of running up or down. (Of course, height is a pretty dull measure. But something more interesting, such as heart rate, would not have worked so well in this example.)
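The same distinction can be seen on a tiny "image". In this base-R sketch (the helpers rot90(), total(), where(), and act_on_location() are hypothetical, not gcnn functions), total intensity plays the role of body height: it is invariant to a 90-degree rotation. The location of a bright pixel, by contrast, is equivariant: it moves, but in a way we can predict exactly from the group element that was applied.

```r
# Rotate a square matrix by 90 degrees (clockwise, as displayed).
rot90 <- function(m) t(m[nrow(m):1, , drop = FALSE])

img <- matrix(0, 5, 5)
img[1, 4] <- 1  # one bright pixel in the top row

# Invariant measurement: unchanged by the rotation.
total <- function(m) sum(m)

# Equivariant measurement: the (row, col) position of the brightest pixel.
where <- function(m) which(m == max(m), arr.ind = TRUE)

# How the same rotation acts on a (row, col) location in an n x n grid.
act_on_location <- function(loc, n) c(loc[2], n + 1 - loc[1])

total(img) == total(rot90(img))  # TRUE
where(rot90(img))                # row 4, col 5
act_on_location(where(img), 5)   # 4 5, the same position
```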

Returning to the implementation, it turns out that group actions are encoded as matrices. There is one matrix for each group element. For \(C_4\), the so-called standard representation is a rotation matrix, with \(\theta\) the element's rotation angle:

\[
\begin{bmatrix} \cos(\theta) & -\sin(\theta) \\ \sin(\theta) & \cos(\theta) \end{bmatrix}
\]
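A quick numerical check of why this matrix is a valid representation: multiplying the matrices of two elements gives the matrix of their product (the sum of their angles), and the inverse element is represented by the transpose. The function name rep_matrix() is a hypothetical helper for this base-R sketch, not gcnn's API.

```r
# Standard representation: one 2 x 2 rotation matrix per angle theta.
# matrix() fills column-wise, so this is [cos -sin; sin cos].
rep_matrix <- function(theta) {
  matrix(c(cos(theta), sin(theta), -sin(theta), cos(theta)), nrow = 2)
}

g1 <- pi / 2
g2 <- pi

# Homomorphism: composing matrices matches adding angles.
all.equal(rep_matrix(g1) %*% rep_matrix(g2), rep_matrix(g1 + g2))  # TRUE

# The inverse element is represented by the transposed matrix.
all.equal(rep_matrix(-g1), t(rep_matrix(g1)))  # TRUE

# Acting on R^2: the point (1, 0) is rotated onto (approximately) (0, 1).
rep_matrix(g1) %*% c(1, 0)
```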

In gcnn, the function applying that matrix is left_action_on_R2(). Like its sibling, it is designed to work with batches (of group elements as well as \(\mathbb{R}^2\) vectors). Technically, what it does is rotate the grid the image is defined on, and then re-sample the image. To make this more concrete, here is roughly what that method's code does.

Here is a goat.

img_path <- system.file("imgs", "z.jpg", package = "gcnn")
img <- torchvision::base_loader(img_path) |> torchvision::transform_to_tensor()
img$permute(c(2, 3, 1)) |> as.array() |> as.raster() |> plot()

A goat sitting comfortably on a meadow.

First, we call C_4$left_action_on_R2() to rotate the grid.

# Grid shape is [2, 1024, 1024], for a 2d, 1024 x 1024 image.
img_grid_R2 <- torch::torch_stack(torch::torch_meshgrid(
    list(
      torch::torch_linspace(-1, 1, dim(img)[2]),
      torch::torch_linspace(-1, 1, dim(img)[3])
    )
))

# Transform the image grid with the matrix representation of some group element.
transformed_grid <- C_4$left_action_on_R2(C_4$inverse(g1)$unsqueeze(1), img_grid_R2)

Second, we re-sample the image on the transformed grid. The goat now looks up to the sky.

transformed_img <- torch::nnf_grid_sample(
  img$unsqueeze(1), transformed_grid,
  align_corners = TRUE, mode = "bilinear", padding_mode = "zeros"
)

transformed_img[1,..]$permute(c(2, 3, 1)) |> as.array() |> as.raster() |> plot()

Same goat, rotated up by 90 degrees.
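The two steps (act on the grid with the inverse group element, then look up pixel values at the transformed locations) can be mimicked in plain R on a small matrix. This is a sketch under simplifying assumptions: a square image on a \([-1, 1]\) grid, and nearest-neighbour lookup instead of the bilinear sampling nnf_grid_sample() performs above. rotate_by_grid() and rep_matrix() are hypothetical helpers. For a 90-degree rotation, the result coincides with an exact index-based rotation.

```r
rep_matrix <- function(theta) {
  matrix(c(cos(theta), sin(theta), -sin(theta), cos(theta)), nrow = 2)
}

rotate_by_grid <- function(img, theta) {
  n <- nrow(img)  # assumes a square image
  coords <- seq(-1, 1, length.out = n)
  # One column per pixel: x runs along columns, y along rows.
  grid <- rbind(x = rep(coords, each = n), y = rep(coords, times = n))
  # Step 1: apply the inverse group element to the grid.
  src <- rep_matrix(-theta) %*% grid
  # Step 2: nearest-neighbour lookup of the source pixel for every target pixel.
  to_index <- function(v) round((v + 1) / 2 * (n - 1)) + 1
  idx <- cbind(row = to_index(src[2, ]), col = to_index(src[1, ]))
  target <- cbind(row = rep(1:n, times = n), col = rep(1:n, each = n))
  # Locations that fall outside the image stay 0, like padding_mode = "zeros".
  inside <- idx[, 1] >= 1 & idx[, 1] <= n & idx[, 2] >= 1 & idx[, 2] <= n
  out <- matrix(0, n, n)
  out[target[inside, , drop = FALSE]] <- img[idx[inside, , drop = FALSE]]
  out
}

img <- matrix(1:16, 4, 4)
rotated <- rotate_by_grid(img, pi / 2)
all(rotated == t(img[4:1, ]))  # TRUE: matches an exact 90-degree index rotation
```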

Step 2: The lifting convolution

We want to make use of existing, efficient torch functionality as much as possible. Concretely, we want to use nn_conv2d(). What we need, though, is a convolution kernel that is equivariant not just to translation, but also to the action of \(C_4\). This can be achieved by having one kernel for each possible rotation.

Implementing that idea is exactly what LiftingConvolution does. The principle is the same as before: First, the grid is rotated, and then, the kernel (weight matrix) is re-sampled to the transformed grid.

Why, though, call this a lifting convolution? The usual convolution kernel operates on \(\mathbb{R}^2\), while our extended version operates on combinations of \(\mathbb{R}^2\) and \(C_4\). In math speak, it has been lifted to the semi-direct product \(\mathbb{R}^2 \rtimes C_4\).
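The "one kernel per rotation" idea fits in a few lines of base R. This is a simplified sketch, not LiftingConvolution itself: it assumes a single input and output channel, square inputs, no padding, and exact 90-degree kernel rotations by index permutation instead of grid re-sampling. rot90(), correlate(), and lift() are hypothetical helpers. The final check shows the defining property: rotating the input rotates every feature map and cyclically shifts the new rotation axis by one step.

```r
# Rotate a square matrix by 90 degrees.
rot90 <- function(m) t(m[nrow(m):1, , drop = FALSE])

# "Valid" cross-correlation (stride 1, no padding), as nn_conv2d() computes it.
correlate <- function(img, ker) {
  k <- nrow(ker)
  m <- nrow(img) - k + 1
  out <- matrix(0, m, m)
  for (i in 1:m) for (j in 1:m) {
    out[i, j] <- sum(img[i:(i + k - 1), j:(j + k - 1)] * ker)
  }
  out
}

# Lifting convolution: correlate with all four rotated copies of the kernel.
# The result is indexed [rotation, row, col], a function on R^2 x C_4.
lift <- function(img, ker) {
  kernels <- list(ker)
  for (r in 2:4) kernels[[r]] <- rot90(kernels[[r - 1]])
  m <- nrow(img) - nrow(ker) + 1
  out <- array(0, c(4, m, m))
  for (r in 1:4) out[r, , ] <- correlate(img, kernels[[r]])
  out
}

set.seed(1)
img <- matrix(rnorm(64), 8, 8)
ker <- matrix(rnorm(9), 3, 3)
out <- lift(img, ker)
out_rot <- lift(rot90(img), ker)

# Slice r of the new output is the rotated slice r - 1 (mod 4) of the old one.
all.equal(out_rot[2, , ], rot90(out[1, , ]))  # TRUE
```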

lifting_conv <- LiftingConvolution(
    group = CyclicGroup(order = 4),
    kernel_size = 5,
    in_channels = 3,
    out_channels = 8
  )

x <- torch::torch_randn(c(2, 3, 32, 32))
y <- lifting_conv(x)
y$shape
[1]  2  8  4 28 28

Since, internally, LiftingConvolution uses an additional dimension to realize the product of translations and rotations, the output is not four-, but five-dimensional: batch, channels, group elements, height, and width.
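The spatial sizes follow the usual arithmetic for an unpadded convolution with stride 1, which is consistent with the printed shapes (an assumption about the layer's defaults, inferred from the output rather than from gcnn's code). The helper below is hypothetical and only restates that arithmetic, with the group axis inserted after the channels.

```r
# Hypothetical helper: output shape of a lifting or group convolution layer,
# assuming stride 1 and no padding.
gconv_output_shape <- function(batch, out_channels, group_order,
                               height, width, kernel_size) {
  c(batch, out_channels, group_order,
    height - kernel_size + 1, width - kernel_size + 1)
}

gconv_output_shape(2, 8, 4, 32, 32, 5)   # 2  8 4 28 28, as for lifting_conv(x)
gconv_output_shape(2, 16, 4, 28, 28, 5)  # 2 16 4 24 24, as for the group conv in Step 3
```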

Step 3: Group convolutions

Now that we're in "group-extended space", we can chain a number of group convolution layers, each of which takes and returns a group-extended tensor. For example:

group_conv <- GroupConvolution(
    group = CyclicGroup(order = 4),
    kernel_size = 5,
    in_channels = 8,
    out_channels = 16
)

z <- group_conv(y)
z$shape
[1]  2 16  4 24 24
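What makes a group-to-group layer equivariant is how its kernel is transformed: for output rotation \(r\), every spatial slice is rotated \(r\) times, and the kernel's group axis is cyclically shifted by \(r\). The base-R sketch below illustrates this for a single channel; it is not GroupConvolution's code, and rot90(), correlate(), and group_conv() are hypothetical helpers. As with the lifting sketch, rotating the input (which for a group-extended tensor means rotating each slice and shifting the group axis) rotates and shifts the output in the same way.

```r
rot90 <- function(m) t(m[nrow(m):1, , drop = FALSE])

correlate <- function(img, ker) {
  k <- nrow(ker)
  m <- nrow(img) - k + 1
  out <- matrix(0, m, m)
  for (i in 1:m) for (j in 1:m) {
    out[i, j] <- sum(img[i:(i + k - 1), j:(j + k - 1)] * ker)
  }
  out
}

# Input f and kernel K are indexed [group element, row, col].
group_conv <- function(f, K) {
  m <- dim(f)[2] - dim(K)[2] + 1
  out <- array(0, c(4, m, m))
  for (r in 0:3) {
    for (s in 0:3) {
      # Kernel for output rotation r: shift the group axis by r, rotate r times.
      slice <- K[(s - r) %% 4 + 1, , ]
      for (i in seq_len(r)) slice <- rot90(slice)
      out[r + 1, , ] <- out[r + 1, , ] + correlate(f[s + 1, , ], slice)
    }
  }
  out
}

set.seed(2)
f <- array(rnorm(4 * 8 * 8), c(4, 8, 8))
K <- array(rnorm(4 * 3 * 3), c(4, 3, 3))

# Act on the input with a 90-degree rotation: rotate slices, shift group axis.
f_rot <- array(0, dim(f))
for (s in 1:4) f_rot[s, , ] <- rot90(f[(s - 2) %% 4 + 1, , ])

out <- group_conv(f, K)
out_rot <- group_conv(f_rot, K)
all.equal(out_rot[2, , ], rot90(out[1, , ]))  # TRUE
```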

All that remains to be done is to package this up. That's what gcnn::GroupEquivariantCNN() does.

Step 4: Group-equivariant CNN

We can call GroupEquivariantCNN() like so.

cnn <- GroupEquivariantCNN(
    group = CyclicGroup(order = 4),
    kernel_size = 5,
    in_channels = 1,
    out_channels = 1,
    num_hidden = 2, # number of group convolutions
    hidden_channels = 16 # number of channels per group conv layer
)

img <- torch::torch_randn(c(4, 1, 32, 32))
cnn(img)$shape
[1] 4 1

At a casual glance, this GroupEquivariantCNN looks like any old CNN … were it not for the group argument.

Now, when we inspect its output, we see that the additional dimension is gone. That's because after a sequence of group-to-group convolution layers, the module projects down to a representation that, for each batch item, retains channels only. It thus averages not just over locations, as we normally do, but over the group dimension as well. A final linear layer then provides the requested classifier output (of dimension out_channels).

And there we have the complete architecture. It is time for a real-world(ish) test.
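Why average over the group dimension too? A 90-degree rotation of the input only permutes the entries of a group-extended feature map (slices get rotated, the group axis gets shifted), so their mean is unchanged. The base-R sketch below checks this for a single-channel lifting convolution; rot90(), correlate(), lift(), and pool() are hypothetical helpers re-defined here so the block runs on its own, and pool() stands in for the per-channel averaging described above.

```r
rot90 <- function(m) t(m[nrow(m):1, , drop = FALSE])

correlate <- function(img, ker) {
  k <- nrow(ker)
  m <- nrow(img) - k + 1
  out <- matrix(0, m, m)
  for (i in 1:m) for (j in 1:m) {
    out[i, j] <- sum(img[i:(i + k - 1), j:(j + k - 1)] * ker)
  }
  out
}

# Lifting convolution: one rotated kernel copy per element of C_4.
lift <- function(img, ker) {
  kernels <- list(ker)
  for (r in 2:4) kernels[[r]] <- rot90(kernels[[r - 1]])
  m <- nrow(img) - nrow(ker) + 1
  out <- array(0, c(4, m, m))
  for (r in 1:4) out[r, , ] <- correlate(img, kernels[[r]])
  out
}

# Average over the group axis and both spatial axes (single channel).
pool <- function(feat) mean(feat)

set.seed(3)
img <- matrix(rnorm(100), 10, 10)
ker <- matrix(rnorm(9), 3, 3)

# Invariant: identical (up to rounding) for the original and the rotated input.
all.equal(pool(lift(img, ker)), pool(lift(rot90(img), ker)))  # TRUE

# For comparison, pooling a plain convolution output will generally differ.
c(mean(correlate(img, ker)), mean(correlate(rot90(img), ker)))
```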

Rotated digits!

The idea is to train two convnets, a "normal" CNN and a group-equivariant one, on the usual MNIST training set. Then, both are evaluated on an augmented test set where each image is randomly rotated by a continuous angle between 0 and 360 degrees. We don't expect GroupEquivariantCNN to be "perfect", not if we equip it with \(C_4\) as a symmetry group. Strictly, with \(C_4\), equivariance extends over four positions only. But we do hope it will perform significantly better than the shift-equivariant-only standard architecture.

First, we prepare the data; in particular, the augmented test set.

dir <- "/tmp/mnist"

train_ds <- torchvision::mnist_dataset(
  dir,
  download = TRUE,
  transform = torchvision::transform_to_tensor
)

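test_ds <- torchvision::mnist_dataset(
  dir,
  train = FALSE,
  # convert to tensor, then rotate by a random angle in [0, 360] degrees
  transform = function(x) {
    x |>
      torchvision::transform_to_tensor() |>
      torchvision::transform_random_rotation(
        degrees = c(0, 360),
        resample = 2,
        fill = 0
      )
  }
)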

train_dl <- dataloader(train_ds, batch_size = 128, shuffle = TRUE)
test_dl <- dataloader(test_ds, batch_size = 128)

How does it look?

test_images <- coro::collect(
  test_dl, 1
)[[1]]$x[1:32, 1, , ] |> as.array()

par(mfrow = c(4, 8), mar = rep(0, 4), mai = rep(0, 4))
test_images |>
  purrr::array_tree(1) |>
  purrr::map(as.raster) |>
  purrr::iwalk(~ {
    plot(.x)
  })

32 digits, rotated randomly.

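We first define and train a conventional CNN. It is as similar to GroupEquivariantCNN(), architecture-wise, as possible, and is given twice the number of hidden channels, so as to have comparable capacity overall.

default_cnn <- nn_module(
  "default_cnn",
  initialize = function(kernel_size, in_channels, out_channels, num_hidden, hidden_channels) {
    self$conv1 <- torch::nn_conv2d(in_channels, hidden_channels, kernel_size)
    self$convs <- torch::nn_module_list()
    for (i in 1:num_hidden) {
      self$convs$append(torch::nn_conv2d(hidden_channels, hidden_channels, kernel_size))
    }
    self$avg_pool <- torch::nn_adaptive_avg_pool2d(1)
    self$final_linear <- torch::nn_linear(hidden_channels, out_channels)
  },
  forward = function(x) {
    # conv -> layer norm (over channels and space) -> ReLU
    x <- x |>
      self$conv1() |>
      (\(h) torch::nnf_layer_norm(h, h$shape[2:4]))() |>
      torch::nnf_relu()
    for (i in 1:length(self$convs)) {
      x <- x |>
        self$convs[[i]]() |>
        (\(h) torch::nnf_layer_norm(h, h$shape[2:4]))() |>
        torch::nnf_relu()
    }
    # average over locations, drop the 1 x 1 spatial dims, classify
    x <- x |>
      self$avg_pool() |>
      torch::torch_flatten(start_dim = 2) |>
      self$final_linear()
    x
  }
)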
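Why twice the channels? A rough weight count for one hidden layer suggests the two models then match. This rests on an assumption, not on gcnn's documented internals: that a group convolution stores one \(k \times k\) kernel slice per output channel, input channel, and group element (as in the University of Amsterdam notebook the package follows). Biases and the first and last layers are ignored here.

```r
k <- 5

# Hidden conv layer of default_cnn: 32 input and 32 output channels.
conv_params <- 32 * 32 * k * k

# Hidden group conv layer of GroupEquivariantCNN: 16 in, 16 out, |C_4| = 4
# (assumed weight layout: out x in x group x k x k).
gconv_params <- 16 * 16 * 4 * k * k

c(conv_params, gconv_params)  # 25600 25600
```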

fitted <- default_cnn |>
    luz::setup(
      loss = torch::nn_cross_entropy_loss(),
      optimizer = torch::optim_adam,
      metrics = list(
        luz::luz_metric_accuracy()
      )
    ) |>
    luz::set_hparams(
      kernel_size = 5,
      in_channels = 1,
      out_channels = 10,
      num_hidden = 4,
      hidden_channels = 32
    ) |>
    luz::set_opt_hparams(lr = 1e-2, weight_decay = 1e-4) |>
    luz::fit(train_dl, epochs = 10, valid_data = test_dl)
Train metrics: Loss: 0.0498 - Acc: 0.9843
Valid metrics: Loss: 3.2445 - Acc: 0.4479

Unsurprisingly, accuracy on the test set is not that great.

Next, we train the group-equivariant version.

fitted <- GroupEquivariantCNN |>
  luz::setup(
    loss = torch::nn_cross_entropy_loss(),
    optimizer = torch::optim_adam,
    metrics = list(
      luz::luz_metric_accuracy()
    )
  ) |>
  luz::set_hparams(
    group = CyclicGroup(order = 4),
    kernel_size = 5,
    in_channels = 1,
    out_channels = 10,
    num_hidden = 4,
    hidden_channels = 16
  ) |>
  luz::set_opt_hparams(lr = 1e-2, weight_decay = 1e-4) |>
  luz::fit(train_dl, epochs = 10, valid_data = test_dl)
Train metrics: Loss: 0.1102 - Acc: 0.9667
Valid metrics: Loss: 0.4969 - Acc: 0.8549

For the group-equivariant CNN, accuracies on test and training sets are a lot closer (0.85 vs. 0.97, compared to 0.45 vs. 0.98 for the standard CNN). That is a nice result! Let's wrap up today's exploration by resuming a thought from the first, more high-level post.

A problem

Going back to the augmented test set, or rather, the samples of digits displayed, we notice a problem. In row two, column four, there is a digit that "under normal circumstances" should be a 9, but most probably is an upside-down 6. (To a human, what suggests this is the squiggle-like thing that seems to be found more often with sixes than with nines.) However, you could ask: does this have to be a problem? Maybe the network just needs to learn the subtleties, the kinds of things a human would spot?

The way I view it, it all depends on the context: What really should be accomplished, and how an application is going to be used. With digits on a letter, I'd see no reason why a single digit should appear upside-down; accordingly, full rotation equivariance would be counter-productive. In a nutshell, we arrive at the same canonical imperative advocates of fair, just machine learning keep reminding us of:

Always think about the way an application is going to be used!

In our case, though, there is another aspect to this, a technical one. gcnn::GroupEquivariantCNN() is a simple wrapper, in that its layers all make use of the same symmetry group. In principle, there is no need to do this. With more coding effort, different groups can be used depending on a layer's position in the feature-detection hierarchy.

Here, let me finally tell you why I chose the goat picture. The goat is seen through a red-and-white fence, a pattern (slightly rotated, due to the viewing angle) made up of squares (or edges, if you like). Now, for such a fence, types of rotation equivariance such as that encoded by \(C_4\) make a lot of sense. The goat itself, though, we'd rather not have look up to the sky, the way I illustrated \(C_4\) action before. Thus, what we might do in a real-world image-classification task is use rather flexible layers at the bottom, and increasingly restrained layers at the top of the hierarchy.

Thanks for reading!

Photo by Marjan Blan | @marjanblan on Unsplash
