In [None]:
import numpy as np
from matplotlib import pyplot as plt

import torch
import torch.nn.functional as F

import skimage.transform
import skimage.data
import skimage.io

def show_img(img):
    """Open a new matplotlib figure showing ``img`` with a colorbar beside it."""
    plt.figure()
    image_artist = plt.imshow(img)
    plt.colorbar(image_artist)
    
def imsave(fname, img):
    """Save one channel of an image tensor as a green-on-black RGB image file.

    Args:
        fname: destination path, passed straight to ``skimage.io.imsave``.
        img: tensor indexed as ``img[0, 0]`` to get a 2-D (H, W) plane, e.g. the
            (1, 1, H, W) tensors produced in this notebook. Values are
            presumably in [0, 1] (TODO confirm for other callers); they are
            inverted, so a value of 0 becomes full green and 1 becomes black.
    """
    # Detach and move to CPU before converting: a tensor that requires grad
    # (or lives on a GPU) cannot be turned into a NumPy array implicitly.
    plane = 1 - img[0, 0].detach().cpu()
    # Red and blue stay zero; only the green channel carries the image.
    rgb = torch.zeros_like(plane).unsqueeze(2).repeat(1, 1, 3)
    rgb[:, :, 1] = plane
    skimage.io.imsave(fname, rgb.numpy())

In [None]:
# "Radius" of the target image: images built from N are (2 * N + 1) pixels on a side.
N = 1024
side = 2 * N + 1

# Target silhouette taken from skimage's bundled sample data; presumably any
# silhouette would work. horse() is presumably a boolean mask, so `* 1` makes it
# numeric before it is normalised to a maximum of 1 and resized to side x side.
horse_mask = skimage.data.horse()
T = horse_mask * 1
T = T / T.max()
T = skimage.transform.resize(T, (side, side))

In [None]:
# Generate a "grid" by superimposing vertical and horizontal sinusoids.
#
# The coordinate tensors are built by broadcasting rather than torch.meshgrid.
# The values are exactly meshgrid's 'ij' output: x varies down the rows and
# y varies across the columns. This also avoids the warning newer PyTorch emits
# when meshgrid is called without an explicit `indexing=` argument.
coords = torch.arange(-N, N + 1)
x = coords[:, None].expand(2 * N + 1, 2 * N + 1)
y = coords[None, :].expand(2 * N + 1, 2 * N + 1)

# base has shape (1, H, W, 2) with coordinates normalised to [-1, 1].
# Channel 0 comes from y (the column index), which is grid_sample's horizontal
# "x" coordinate. Used as a grid with align_corners=True, base alone is
# therefore (up to float rounding) an identity warp.
base = (torch.stack([y, x]) * 1. / N).permute(1, 2, 0).unsqueeze(0)

# base * N recovers integer pixel offsets, so each sinusoid has a period of
# 2*pi pixels. vert varies across columns (vertical stripes); horz varies
# down rows (horizontal stripes). Both are mapped into [0, 1].
vert = ((base[:, :, :, 0] * N).sin() + 1) / 2
horz = ((base[:, :, :, 1] * N).sin() + 1) / 2

# "Sharpen" the sinusoids to resemble a fine mesh, then add a channel axis so
# img is (1, 1, H, W), the layout grid_sample expects for its input.
img = 1 - (1 - vert * horz) ** 8
img = img.unsqueeze(1)

def render(offs, verbose=False, max_offset=0.01):
    """Produce the moire pattern for a low-resolution grid of offsets.

    The coarse offsets are squashed, upsampled to full resolution and used to
    warp the fine mesh ``img``. The warped mesh is then overlaid on the
    unwarped one.

    Args:
        offs: offset logits shaped (1, 2, h, w) for any coarse h, w (this
            notebook uses h = w = K). After upsampling, channel 0 is added to
            grid_sample's horizontal (x) coordinate and channel 1 to its
            vertical (y) coordinate.
        verbose: if True, also write the warped mesh to 'warp.png' and the
            unwarped mesh to 'screen.png'.
        max_offset: upper bound on each offset, in grid_sample's normalised
            [-1, 1] coordinates. The default of 0.01 is about N / 100 pixels
            and reproduces the previous hard-coded ``/ 100``.

    Returns:
        The (1, 1, H, W) product of the warped and unwarped meshes.
    """
    # sigmoid maps into (0, 1), so every offset lies in (0, max_offset). Note
    # the shift is always positive rather than symmetric around zero.
    scaled = offs.sigmoid() * max_offset

    # Bicubically upsample the coarse offsets to full resolution.
    # base is (1, H, W, 2), so base.shape[-3:-1] == (H, W). The channel axis is
    # then moved last to match grid_sample's (N, H, W, 2) grid layout.
    dense = F.interpolate(
        scaled,
        size=base.shape[-3:-1],
        mode='bicubic',
        align_corners=True
    ).permute(0, 2, 3, 1)

    # Warp the mesh by sampling it at the perturbed coordinates.
    grid = base + dense
    warped = F.grid_sample(img, grid, align_corners=True)
    if verbose:
        # Detach so that saving also works when offs requires grad.
        imsave('warp.png', warped.detach())
        imsave('screen.png', img)

    # Overlay the warped mesh on the unperturbed mesh.
    return warped * img

In [None]:
# Optimise low-resolution (K x K) offsets so that the rendered moire pattern
# matches the target silhouette T.
K = 64
SEED = 0
# Bounded so the cell terminates under Restart & Run All; raise this for a
# more converged pattern.
N_ITERS = 1000
LOG_EVERY = 100
# Plain gradient-descent step size. It is tied to K * K, presumably to offset
# the tiny per-element gradients of a mean over (2N + 1)^2 pixels (TODO confirm).
LR = K * K

torch.manual_seed(SEED)
a = torch.randn((1, 2, K, K))
# Convert the target once, outside the loop, and match render's float32
# output. This avoids re-copying a (2N + 1)^2 array every step.
target = torch.as_tensor(T, dtype=torch.float32)

for i in range(1, N_ITERS + 1):
    a.requires_grad_()
    out = render(a)
    loss = ((out - target) ** 2).mean()
    if (i - 1) % LOG_EVERY == 0:
        plt.imsave('out-%05d.png' % i, out.detach()[0, 0].numpy())
        print('iter %d: loss %.6g' % (i, loss.item()))
    loss.backward()
    # Take a manual gradient step; the result is a fresh leaf tensor with no history.
    a = a.detach() - LR * a.grad.detach()

In [None]:
# Render the optimised offsets at full resolution. verbose=True also writes
# warp.png and screen.png. The result is displayed as a figure instead of
# letting the raw (1, 1, H, W) tensor repr become the cell output.
moire = render(a.detach(), verbose=True)
show_img(moire[0, 0].numpy())