In [None]:
import numpy as np
from matplotlib import pyplot as plt
import torch

In [None]:
def heron(V, F):
    """Total surface area of a triangle mesh, computed with Heron's formula.

    Parameters
    ----------
    V : torch.Tensor, shape (n_vertices, 3)
        Vertex positions (x, y, z). May require grad; the result is
        differentiable with respect to V.
    F : array-like or torch.Tensor of int, shape (n_faces, 3)
        For each face, the indices of its three vertices in V. A numpy int32
        array (e.g. ``Triangulation.triangles``) is accepted and converted to
        a long tensor on V's device.

    Returns
    -------
    torch.Tensor
        0-d tensor holding the sum of all triangle areas.

    Notes
    -----
    The gradient is still infinite for triangles whose area is exactly zero
    (sqrt has an infinite derivative at 0); only round-off negatives are fixed.
    """
    F = torch.as_tensor(F, dtype=torch.long, device=V.device)
    tris = V[F]                                     # face * vertex * position
    dels = tris.roll(-1, dims=1) - tris             # edge vectors v[k+1] - v[k]
    lens = (dels ** 2).sum(dim=2).sqrt()            # face * edge -> length
    semi = lens.sum(dim=1, keepdim=True) / 2        # semi-perimeter s
    abcs = (semi - lens).prod(dim=1, keepdim=True)  # (s-a)(s-b)(s-c)
    # For (near-)degenerate triangles, cancellation in s - a etc. can make the
    # product slightly negative; sqrt would then yield NaN and poison the sum.
    area = (semi * abcs).clamp(min=0).sqrt()
    return area.sum()

from mpl_toolkits.mplot3d import Axes3D
def draw_it(V, F, name=None, azim=45, zlim=(0, 1)):
    """Render the triangle mesh (V, F) as a 3-D surface and optionally save it.

    Parameters
    ----------
    V : array-like or torch.Tensor, shape (n_vertices, 3)
        Vertex positions; columns 0/1/2 are x/y/z. A tensor is detached and
        moved to the CPU first, so one that still requires grad is accepted.
    F : array-like of int, shape (n_faces, 3)
        Vertex indices of each triangle.
    name : str or None
        If given, the figure is saved to this path (format from the
        extension). If None, nothing is saved and nothing is shown.
    azim : float
        Camera azimuth in degrees.
    zlim : (float, float)
        z-axis limits; the default (0, 1) is the previously fixed range.
    """
    if isinstance(V, torch.Tensor):
        # matplotlib cannot convert a tensor that requires grad or lives on GPU.
        V = V.detach().cpu().numpy()
    fig = plt.figure()
    try:
        ax = fig.add_subplot(111, projection='3d')
        ax.plot_trisurf(V[:, 0], V[:, 1], V[:, 2], triangles=F)
        ax.set_zlim(*zlim)
        ax.view_init(azim=azim)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_zticks([])
        if name is not None:
            fig.savefig(name)
    finally:
        # Always release the figure (even if saving fails) so repeated calls
        # do not accumulate open figures.
        plt.close(fig)

In [None]:
# Regular 15 x 15 parameter grid on the unit square, flattened so that
# (u[k], v[k]) are the parameter coordinates of vertex k.
_grid_axis = np.linspace(0, 1, 15, endpoint=True)
u, v = (coords.flatten() for coords in np.meshgrid(_grid_axis, _grid_axis))

# Indices of vertices on the square's border (either coordinate at 0 or 1);
# linspace with endpoint=True hits 0 and 1 exactly, so == is safe here.
_on_border = (u == 0) | (u == 1) | (v == 0) | (v == 1)
E = torch.tensor(np.argwhere(_on_border).flatten())
num_E = len(E)

def case_1():
    """Flat square: the grid mapped to [-1, 1] x [-1, 1] in the plane z = 0.

    Reads the module-level flattened grid coordinates ``u`` and ``v`` and
    returns three float32 tensors (x, y, z), one entry per grid vertex.
    """
    def to_centered(coords):
        # [0, 1] -> [-1, 1]
        return (torch.tensor(coords, dtype=torch.float32) - 0.5) * 2

    x, y = to_centered(u), to_centered(v)
    return x, y, x * 0.

def case_2():
    """Roof: two planes meeting in a ridge along x = 0.

    Reads the module-level flattened grid coordinates ``u`` and ``v``.
    x spans [-1/sqrt(2), 1/sqrt(2)], y spans [-1, 1], and
    z = 1/sqrt(2) - |x|, so the long edges |x| = 1/sqrt(2) lie on z = 0 and
    the ridge is at z = 1/sqrt(2). Returns float32 tensors (x, y, z).
    """
    half_width = 2 ** -0.5
    x = (torch.tensor(u, dtype=torch.float32) - 0.5) * 2 / (2 ** 0.5)
    y = (torch.tensor(v, dtype=torch.float32) - 0.5) * 2
    # The ridge height must equal the x half-width so the boundary edges land
    # exactly on z = 0; the former 1./1.41 left them about 2e-3 above it.
    z = half_width - x.abs()
    return x, y, z

def case_3a():
    """Flat floor z = 0 on [-1, 1]^2 with the two edges x = +/-1 lifted to z = 1.

    Reads the module-level flattened grid coordinates ``u`` and ``v`` and
    returns float32 tensors (x, y, z).
    """
    xs = (torch.tensor(u, dtype=torch.float32) - 0.5) * 2
    ys = (torch.tensor(v, dtype=torch.float32) - 0.5) * 2
    on_wall = xs.abs() == 1.
    zs = torch.where(on_wall, torch.ones_like(xs), xs * 0.)
    return xs, ys, zs

def case_3b():
    """Tapered frame: raised edges on a trapezoid-shaped domain.

    Reads the module-level flattened grid coordinates ``u`` and ``v``.
    On the edges x = +/-1, z = (1 - |y|) / sqrt(2); on the edge y = 1,
    z = (1 - |x|) / sqrt(2); elsewhere z = 0. Afterwards x is scaled by
    (y + 1.1) / 2, narrowing the domain toward y = -1.
    Returns float32 tensors (x, y, z).
    """
    x = (torch.tensor(u, dtype=torch.float32) - 0.5) * 2
    y = (torch.tensor(v, dtype=torch.float32) - 0.5) * 2

    side_edges = x.abs() == 1.
    top_edge = y == 1.

    z = x * 0.
    z[side_edges] = (1. - y[side_edges].abs()) / (2 ** 0.5)
    z[top_edge] = (1. - x[top_edge].abs()) / (2 ** 0.5)

    # Taper uses the unscaled x above; width factor runs 0.05 (y=-1) .. 1.05 (y=1).
    x = x * (y + 1.1) / 2
    return x, y, z

In [None]:
# Minimise the mesh area of case_1 by gradient descent with the boundary
# pinned, rendering progress frames into case_1.gif.
import glob
import os
import shutil
import subprocess
import tempfile

import matplotlib.tri as mtri

N_STEPS = 30000      # gradient-descent iterations
LR = 0.01            # step size
FRAME_EVERY = 500    # render a frame and log the area every this many steps
GIF_NAME = 'case_1.gif'

x, y, z = case_1()
V = torch.stack([x, y, z]).t()  # one row (x, y, z) per vertex

# Triangulate in the (u, v) parameter plane; the same connectivity is used in 3-D.
tri = mtri.Triangulation(u, v)
F = tri.triangles

# Frames go to a private temporary directory instead of the CWD, so we never
# `rm *.jpg` unrelated files and leftover frames are removed even on interrupt.
with tempfile.TemporaryDirectory() as frame_dir:
    for i in range(N_STEPS):
        V.requires_grad_()
        loss = heron(V, F)
        loss.backward()
        V.grad[E] = 0.  # keep boundary vertices fixed
        V = (V - V.grad * LR).detach()
        if i % FRAME_EVERY == 0:
            frame_path = os.path.join(frame_dir, '%05d.jpg' % i)
            draw_it(V, F, frame_path, azim=45 + i * 180 // N_STEPS)
            print(f'step {i:5d}  area {loss.item():.6f}')

    # Zero-padded names make lexical order equal step order.
    frames = sorted(glob.glob(os.path.join(frame_dir, '*.jpg')))
    if shutil.which('convert') is None:
        print('ImageMagick `convert` not found; skipping', GIF_NAME)
    else:
        subprocess.run(['convert', '-loop', '0', *frames, GIF_NAME], check=True)