Skip to content

Commit 6fc7fd5

Browse files
taiyacopybara-github
authored andcommitted
colab elements of cvxnet
PiperOrigin-RevId: 438883571
1 parent 3bcd105 commit 6fc7fd5

File tree

4 files changed

+320
-0
lines changed

4 files changed

+320
-0
lines changed
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
# Copyright 2020 The TensorFlow Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Example of conversion between convex hyperplanes and mesh."""
15+
16+
# --- being forgiving as this is a colab
17+
# pylint: disable=invalid-name
18+
# pylint: disable=redefined-outer-name
19+
# pylint: disable=missing-function-docstring
20+
# pylint: disable=using-constant-test
21+
22+
import os
23+
import shutil
24+
25+
import matplotlib.pyplot as plt
26+
from mpl_toolkits.mplot3d import Axes3D
27+
import numpy as np
28+
from scipy.spatial import HalfspaceIntersection
29+
import trimesh
30+
31+
32+
def load_npz(path):
33+
"""Load a halfspace definition from a numpy file."""
34+
data = np.load(path)
35+
trans = data['trans']
36+
planes = data['planes']
37+
T, C, H = planes.shape[0:3]
38+
return T, C, H, trans, planes
39+
40+
41+
def load_cube():
42+
"""A dummy halfspace definition of a cube."""
43+
T = 1 #< temporal
44+
C = 3 #< num convexes
45+
H = 6 #< num planes
46+
trans = np.zeros([T, C, 3])
47+
planes = np.zeros([T, C, H, 4])
48+
cube = np.array([[-1., 0., 0., -.1], [1., 0., 0., -.1], [0., -1., 0., -.1],
49+
[0., 1., 0., -.1], [0., 0., -1., -.1], [0., 0., 1., -.1]])
50+
planes[0, 0, ...] = cube
51+
planes[0, 1, ...] = cube
52+
planes[0, 2, ...] = cube
53+
trans[0, 0, ...] = np.array([0, 0, 0])
54+
trans[0, 1, ...] = np.array([+.25, 0, 0])
55+
trans[0, 2, ...] = np.array([-.25, 0, 0])
56+
return T, C, H, trans, planes
57+
58+
59+
T, C, H, trans, planes = load_npz('example.npz')
60+
T, C, H, trans, planes = load_cube()
61+
62+
63+
def halfspaces_to_vertices(halfspaces):
64+
# pre-condition: euclidean origin is within the halfspaces
65+
# Input: Hx(D+1) numpy array of halfplane constraints
66+
# Output: convex hull vertices
67+
n_dims = halfspaces.shape[1]-1
68+
feasible_point = np.zeros([n_dims,])
69+
hs = HalfspaceIntersection(halfspaces, feasible_point)
70+
return hs.intersections
71+
72+
73+
def vertices_to_convex(vertices):
74+
mesh = trimesh.Trimesh(vertices=vertices, faces=None)
75+
return mesh.convex_hull.vertices, mesh.convex_hull.faces
76+
77+
78+
def plot_wireframe(ax, vertices, triangles):
79+
xs, ys, zs = zip(*vertices)
80+
ax.plot_trisurf(
81+
xs,
82+
ys,
83+
zs,
84+
triangles=triangles,
85+
shade=True,
86+
linewidth=0.1,
87+
edgecolors=(0, 0, 0),
88+
antialiased=True)
89+
ax.set_xlim(-.5, .5)
90+
ax.set_ylim(-.5, .5)
91+
ax.set_zlim(-.5, .5)
92+
93+
94+
def makedirs(path):
95+
os.makedirs(path)
96+
97+
98+
def deletedirs(path):
99+
if os.path.isdir(path):
100+
shutil.rmtree(path)
101+
102+
103+
def export_mesh(path, vertices, faces):
104+
mesh = trimesh.Trimesh(vertices=vertices, faces=faces)
105+
mesh.export(path)
106+
107+
108+
# --- Reset environment
109+
deletedirs('data/')
110+
fig = plt.figure()
111+
ax = Axes3D(fig)
112+
113+
114+
for iT in range(T):
115+
iT_folder = 'data/frame_{0:02d}/'.format(iT)
116+
makedirs(iT_folder)
117+
obj_shape = planes[iT, ...]
118+
obj_trans = trans[iT, ...]
119+
120+
combo = list()
121+
for iC in range(C):
122+
cvx_halfspaces = obj_shape[iC, ...]
123+
cvx_translation = obj_trans[iC, ...]
124+
vertices = halfspaces_to_vertices(cvx_halfspaces)
125+
vertices, faces = vertices_to_convex(vertices)
126+
vertices += cvx_translation
127+
plot_wireframe(ax, vertices, faces)
128+
129+
if True:
130+
iC_folder = iT_folder+'cvx_{0:02d}.obj'.format(iC)
131+
export_mesh(iC_folder, vertices, faces)
132+
plot_wireframe(ax, vertices, faces)
133+
134+
combo.append(trimesh.Trimesh(vertices=vertices, faces=faces))
135+
136+
if False:
137+
combo = np.sum(combo)
138+
combo.export(iT_folder+'/mesh.obj')
139+
plot_wireframe(ax, combo.vertices, combo.faces)
140+
141+
plt.show()
Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
# Copyright 2020 The TensorFlow Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""A simple 2D demo of the differentiable convex function."""
15+
16+
# --- being forgiving as this is a colab
17+
# pylint: skip-file
18+
19+
#%% Load the data (from point picker)
20+
import numpy as np
21+
import matplotlib.pyplot as plt
22+
import numpy as np
23+
import matplotlib.image as mpimg
24+
25+
# --- Equations of hyperplanes in the 'hyperplanes.png' image
26+
h0 = np.array([(223.84848484848487, 55.04545454545456),
27+
(97.78787878787875, 91.4848484848485)])
28+
h1 = np.array([(96.80303030303028, 91.4848484848485),
29+
(62.333333333333286, 239.21212121212125)])
30+
h2 = np.array([(62.333333333333286, 239.21212121212125),
31+
(134.2272727272727, 311.1060606060606)])
32+
h3 = np.array([(134.2272727272727, 311.1060606060606),
33+
(264.22727272727275, 161.40909090909093)])
34+
h4 = np.array([(264.22727272727275, 161.40909090909093),
35+
(223.84848484848487, 55.04545454545456)])
36+
h5 = np.array([(234.6818181818182, 327.8484848484849),
37+
(333.1666666666667, 159.43939393939394)])
38+
hs = [h0, h1, h2, h3, h4, h5]
39+
40+
# --- Load base image
41+
img = mpimg.imread('hyperplanes.png')
42+
43+
if False:
44+
#--- Check lines match PNG
45+
plt.figure(0)
46+
imgplot = plt.imshow(img)
47+
48+
def ploth(h):
49+
plt.plot(h[0][0], h[0][1], '.r')
50+
plt.plot(h[1][0], h[1][1], '.r')
51+
52+
for h in hs:
53+
ploth(h)
54+
55+
56+
def pointnormal(h):
57+
ROT = np.array([[0, -1], [1, 0]])
58+
p1 = np.array(h[0][:])
59+
p2 = np.array(h[1][:])
60+
n = (p2 - p1) / np.linalg.norm(p2 - p1)
61+
return p1, np.dot(ROT, n)
62+
63+
64+
#--- Define sampling domain
65+
x = np.linspace(0, 364, 364)
66+
y = np.linspace(0, 364, 364)
67+
XX, YY = np.meshgrid(x, y)
68+
69+
#--- Compute the SDFs
70+
D = np.zeros((len(hs), img.shape[0], img.shape[1]))
71+
for i, hi in enumerate(hs):
72+
p0, n0 = pointnormal(hi)
73+
XY = np.stack([XX, YY])
74+
p0 = np.reshape(p0, [2, 1, 1]) # (2,1,1)
75+
n0 = np.reshape(n0, [2, 1, 1])
76+
off = (XY - p0) #< broadcat (2,W,H)
77+
d = np.linalg.norm(off, axis=0)
78+
d = np.einsum('i...,i...', n0, off)
79+
D[i, ...] = d
80+
81+
# softmax = lambda x, delta: np.exp(delta*x) / np.sum(np.exp(delta*x), axis=0)
82+
softmax = lambda x, delta=1: np.log(np.sum(np.exp(delta * x), axis=0)) / delta
83+
Dmax = softmax(D)
84+
85+
# Dmax = D.max(axis=0)
86+
D_clim = np.maximum(D.max(), -D.min())
87+
Dmax_clim = np.maximum(Dmax.max(), -Dmax.min())
88+
Dshift = Dmax #< ?what was this?
89+
90+
sigmoid = lambda x, sigma: 1 / (1 + np.exp(sigma * x))
91+
Dout = sigmoid(Dshift, 1 / 10.)
92+
93+
#%%
94+
#--- individual
95+
get_ipython().system('mkdir cvxdec')
96+
for i, hi in enumerate(hs):
97+
d = D[i, ...]
98+
plt.figure(i)
99+
plt.imshow(d, cmap=plt.get_cmap('coolwarm'), clim=(-D_clim, +D_clim))
100+
plt.contour(d, [0])
101+
plt.axis('off')
102+
plt.gca().xaxis.set_major_locator(plt.NullLocator())
103+
plt.gca().yaxis.set_major_locator(plt.NullLocator())
104+
plt.savefig(
105+
'cvxdec/sdf_{}.png'.format(i), bbox_inches='tight', pad_inches=-.1)
106+
107+
#%%
108+
#--- Display a single one + the colormap beside it
109+
for i, hi in enumerate(hs):
110+
d = D[i, ...]
111+
plt.figure(i)
112+
imaxis = plt.imshow(d, cmap=plt.get_cmap('coolwarm'), clim=(-D_clim, +D_clim))
113+
plt.contour(d, [0])
114+
plt.gcf().colorbar(imaxis)
115+
plt.axis('off')
116+
plt.gca().xaxis.set_major_locator(plt.NullLocator())
117+
plt.gca().yaxis.set_major_locator(plt.NullLocator())
118+
plt.savefig(
119+
'cvxdec/sdf_{}_cmap.png'.format(i), bbox_inches='tight', pad_inches=-.1)
120+
break
121+
122+
#%%
123+
#--- max / union
124+
plt.figure()
125+
imaxis = plt.imshow(
126+
Dmax, cmap=plt.get_cmap('coolwarm'), clim=(-Dmax_clim, +Dmax_clim))
127+
plt.contour(Dmax, [0])
128+
plt.axis('off')
129+
plt.gcf().colorbar(imaxis)
130+
plt.gca().xaxis.set_major_locator(plt.NullLocator())
131+
plt.gca().yaxis.set_major_locator(plt.NullLocator())
132+
plt.savefig('cvxdec/maxoperator_cmap.png', bbox_inches='tight', pad_inches=0)
133+
plt.show()
134+
135+
#%%
136+
#--- max / union with different thresholds
137+
Dmax_news = list()
138+
for idelta, delta in enumerate([0.040, 0.060, 0.080, 1]):
139+
Dmax_new = softmax(D, delta)
140+
Dmax_news.append(Dmax_new)
141+
if True:
142+
print(delta)
143+
plt.figure(idelta, frameon=False)
144+
imaxis = plt.imshow(
145+
Dmax, cmap=plt.get_cmap('coolwarm'), clim=(-Dmax_clim, +Dmax_clim))
146+
plt.contour(Dmax_new, [0])
147+
# plt.gcf().colorbar(imaxis)
148+
plt.axis('off')
149+
plt.gca().xaxis.set_major_locator(plt.NullLocator())
150+
plt.gca().yaxis.set_major_locator(plt.NullLocator())
151+
plt.savefig(
152+
'cvxdec/softmax_{}.png'.format(delta),
153+
bbox_inches='tight',
154+
pad_inches=-.1)
155+
plt.show()
156+
157+
#%%
158+
#--- sigmoid
159+
Dshift = Dmax_news[2]
160+
for isigma, sigma in enumerate([1 / 5]):
161+
Dout = sigmoid(Dshift, sigma)
162+
if True:
163+
plt.figure(isigma, frameon=False)
164+
imaxis = plt.imshow(Dout, cmap=plt.get_cmap('coolwarm'), clim=(0, 1))
165+
plt.contour(Dout, [0.5])
166+
plt.axis('off')
167+
# plt.gcf().colorbar(imaxis)
168+
plt.gca().xaxis.set_major_locator(plt.NullLocator())
169+
plt.gca().yaxis.set_major_locator(plt.NullLocator())
170+
plt.savefig(
171+
'cvxdec/sigmoid_{}.png'.format(sigma),
172+
bbox_inches='tight',
173+
pad_inches=-.1)
174+
plt.show()
175+
176+
#%%
177+
#--- 2D visualization
178+
plt.figure()
179+
plt.plot(Dout[182, :])
8.86 KB
Binary file not shown.
37.8 KB
Loading

0 commit comments

Comments
 (0)