(ICRA 2021) ScrewNet

Notes from reading the ScrewNet paper, together with an analysis of its source code.

Core Idea

ScrewNet's goal is actually quite simple: estimate an articulated object's joint model and its configuration directly from depth images, whereas previous work either needs additional texture information about the articulated object or requires its joint type (rigid, revolute, prismatic, helical) to be specified in advance.

The core architecture of ScrewNet is as follows:

Given N depth frames, every frame is first passed through a ResNet-18 backbone to extract 2D image features; the N resulting feature vectors are fed into an LSTM, which finally outputs (N−1) × 8 relative screw parameters. (The original post links a blog article here for brushing up on LSTMs.)

Screw Theory

This is the key point that got the paper into ICRA: "Any displacement of a rigid body in space can be produced by a rotation about a line together with a translation along that same line." That line is called the screw axis of the displacement, $S$. In Plücker coordinates a line is written as $(\textbf{l}, \textbf{m})$, subject to the two constraints $||\textbf{l}||=1$ and $\textbf{l}\cdot\textbf{m}=0$. This is easy to see: in Euclidean coordinates, take a segment of the line with endpoints $x$ and $y$, and let
$$
\textbf{l}=\text{normalize}(y-x) \\
\textbf{m}=x\times\textbf{l}
$$
This is exactly how a line is represented in Plücker coordinates: $\textbf{m}$ is the moment of the line about the origin, and it does not depend on which point $x$ of the line is chosen.

The $(d, m)$ in the figure above is exactly our $(\textbf{l}, \textbf{m})$.

ScrewNet therefore uses $(\textbf{l}, \textbf{m}, \theta, d)$ to represent rigid-body motion in SE(3): $d$ is the linear translation along the axis and $\theta$ the rotation about it, and the two are related by $d=h\theta$, where $h$ is the pitch of the screw.

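As a small illustration (my own sketch, not code from the ScrewNet repo): build the Plücker coordinates of a line from two points, check the two constraints, and apply a screw displacement $(\textbf{l}, \textbf{m}, \theta, d)$ to a point.

import numpy as np

def plucker_from_points(x, y):
    """Plücker coordinates (l, m) of the line through points x and y."""
    l = (y - x) / np.linalg.norm(y - x)   # unit direction
    m = np.cross(x, l)                    # moment of the line about the origin
    return l, m

def rot_axis_angle(l, theta):
    """Rotation matrix for angle theta about unit axis l (Rodrigues' formula)."""
    K = np.array([[0, -l[2], l[1]], [l[2], 0, -l[0]], [-l[1], l[0], 0]])
    return np.eye(3) + np.sin(theta) * K + (1 - np.cos(theta)) * (K @ K)

def apply_screw(p, l, m, theta, d):
    """Rotate p by theta about the screw axis (l, m) and translate d along it."""
    c = np.cross(l, m)                    # point of the axis closest to the origin
    R = rot_axis_angle(l, theta)
    return R @ (p - c) + c + d * l

x, y = np.array([1.0, 0.0, 0.0]), np.array([1.0, 0.0, 1.0])
l, m = plucker_from_points(x, y)
print(np.linalg.norm(l), np.dot(l, m))    # -> 1.0 and 0.0, the two Plücker constraints
print(apply_screw(np.array([0.0, 0.0, 0.0]), l, m, np.pi / 2, 0.3))
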
Loss Function

A screw displacement consists of two parts: the screw axis $S$ and the corresponding configurations $q_i$. ScrewNet therefore optimizes several objectives jointly:
$$
L=\lambda_1L_{S_{ori}}+\lambda_2L_{S_{dist}}+\lambda_3L_{S_{cons}}+\lambda_4{L_q}
$$
Here $L_{S_{ori}}$ penalizes the orientation error of the screw axis and is computed as the angular deviation between the predicted and ground-truth axes, while $L_{S_{dist}}$ penalizes the spatial distance between the predicted axis and the ground-truth axis, expressed as the distance between two lines in Plücker coordinates:

$$
d((\textbf{l}_1,\textbf{m}_1),(\textbf{l}_2,\textbf{m}_2))=
\begin{cases}
0, & \text{if $\textbf{l}_1$ and $\textbf{l}_2$ intersect} \\
||\textbf{l}_1\times(\textbf{m}_1-\textbf{m}_2)||, & \text{if $\textbf{l}_1$ and $\textbf{l}_2$ are parallel, i.e. } ||\textbf{l}_1\times \textbf{l}_2||=0 \\
\frac{|\textbf{l}_1\cdot \textbf{m}_2+\textbf{l}_2 \cdot \textbf{m}_1|}{||\textbf{l}_1\times \textbf{l}_2||}, & \text{otherwise, $\textbf{l}_1$ and $\textbf{l}_2$ are skew lines}
\end{cases}\\
L_{S_{dist}}=d((\textbf{l}_{GT},\textbf{m}_{GT}),(\textbf{l}_{pred},\textbf{m}_{pred}))
$$
$L_{S_{cons}}$ forces the predicted line to satisfy the Plücker constraints $\textbf{l}\cdot\textbf{m}=0$ and $||\textbf{l}||=1$, and $L_q$ is the configuration loss.

$L_q$ itself consists of two parts, a rotation error $L_{\theta}$ and a translation error $L_d$, computed as:
$$
L_q=\alpha_1L_{\theta}+\alpha_2L_d \\
L_{\theta}=||I_{3\times3}-R(\theta_{GT},\textbf{l}_{GT})R(\theta_{pred},\textbf{l}_{pred})^T||_F \\
L_d=||d_{GT}\cdot\textbf{l}_{GT}-d_{pred}\cdot\textbf{l}_{pred}||
$$
Here $R(\theta,\textbf{l})$ is the rotation matrix for a rotation by $\theta$ about the axis $\textbf{l}$. The reason for not simply applying an $L_2$ loss to $q_{GT}$ and $q_{pred}$ is that this formulation preserves the physical meaning: it is built on the orthogonality of rotation matrices, so it guarantees that the learned $\theta_{pred}$ and $\textbf{l}_{pred}$ correspond to a proper rotation $R(\theta_{pred},\textbf{l}_{pred})\in SO(3)$. Likewise, $L_d$ compares the translations along the two different axes $\textbf{l}_{GT}$ and $\textbf{l}_{pred}$; if we only compared the magnitudes $d_{GT}$ and $d_{pred}$, we would implicitly assume the two translations happen along the same axis, which is not reasonable.

To sum up, the loss terms are chosen to follow the screw formulation introduced above.

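To make the ingredients above concrete, here is a single-pair NumPy sketch of my own (unit direction vectors are assumed); the repo's batched PyTorch versions appear in loss.py below.

import numpy as np

def rot_axis_angle(l, theta):
    """Rotation matrix for angle theta about unit axis l (Rodrigues' formula)."""
    K = np.array([[0, -l[2], l[1]], [l[2], 0, -l[0]], [-l[1], l[0], 0]])
    return np.eye(3) + np.sin(theta) * K + (1 - np.cos(theta)) * (K @ K)

def screw_loss_terms(gt, pred, eps=1e-10):
    """gt / pred are 8-vectors <l, m, theta, d>; the directions are assumed to be unit vectors."""
    l1, m1, th1, d1 = gt[:3], gt[3:6], gt[6], gt[7]
    l2, m2, th2, d2 = pred[:3], pred[3:6], pred[6], pred[7]

    # axis orientation error: angle between the two direction vectors
    L_ori = np.arccos(np.clip(np.dot(l1, l2), -1.0, 1.0))

    # axis distance in Plücker coordinates (parallel vs. skew/intersecting case)
    cross = np.cross(l1, l2)
    if np.linalg.norm(cross) <= eps:
        L_dist = np.linalg.norm(np.cross(l1, m1 - m2))
    else:
        L_dist = abs(np.dot(l1, m2) + np.dot(l2, m1)) / np.linalg.norm(cross)

    # configuration errors
    R1, R2 = rot_axis_angle(l1, th1), rot_axis_angle(l2, th2)
    L_theta = np.linalg.norm(np.eye(3) - R2 @ R1.T)      # Frobenius norm
    L_d = np.linalg.norm(d1 * l1 - d2 * l2)              # translations compared as vectors
    return L_ori, L_dist, L_theta, L_d

gt = np.array([0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.5, 0.1])       # axis through the origin along z
pred = np.array([0.0, 0.0, 1.0, 0.0, -1.0, 0.0, 0.4, 0.2])    # same direction, axis through (1, 0, 0)
print(screw_loss_terms(gt, pred))   # -> (0.0, 1.0, approx 0.14, 0.1)
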
Labeling Method

ScrewNet's training set consists of sequences of depth images together with the corresponding screw displacements. The articulated objects are rendered in the MuJoCo simulator and the depth images are recorded; the objects used from the dataset include cabinets, drawers, microwaves, ovens and so on.

To create the screw-displacement labels, we treat $o_i$ as the base object and compute the screw displacement of each subsequent $o_j$ relative to it. Concretely, given a stream of N frames $I_{1:N}$, the first frame is chosen as the object's base pose, and the N−1 relative screw displacements of the remaining frames are computed with respect to it.

This gives N−1 displacements expressed in the frame $F_{O_j^1}$; by transforming lines in Plücker coordinates we can convert all N−1 relative displacements into the base object frame $O_i$. The exact form of this Plücker-coordinate transformation can be found in the paper; a sketch of the standard rule is given below.

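A minimal sketch of my own, assuming the two frames are related by a rigid transform that maps points as $p' = Rp + t$: the direction transforms as $\textbf{l}' = R\textbf{l}$ and the moment as $\textbf{m}' = R\textbf{m} + t\times(R\textbf{l})$. This is presumably what the repo's transform_plucker_line helper (used in dataset.py below) implements.

import numpy as np

def transform_plucker(l, m, R, t):
    """Re-express a Plücker line (l, m) in a frame where points map as p' = R @ p + t."""
    l_new = R @ l
    m_new = R @ m + np.cross(t, l_new)
    return l_new, m_new

# consistency check: a point p on the line still satisfies m' = p' x l' afterwards
R = np.array([[0.0, -1.0, 0.0],
              [1.0,  0.0, 0.0],
              [0.0,  0.0, 1.0]])          # 90 degrees about z
t = np.array([0.5, -0.2, 1.0])
p = np.array([1.0, 2.0, 3.0])
l = np.array([0.0, 1.0, 0.0])
m = np.cross(p, l)
l_new, m_new = transform_plucker(l, m, R, t)
print(np.allclose(m_new, np.cross(R @ p + t, l_new)))   # -> True
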
Code Walkthrough

1.models.py

The code provides three models: ScrewNet, ScrewNet_2imgs and ScrewNet_NoLSTM. In training and testing they are paired with datasets as follows:

if args.model_type == '2imgs':
    print("Testing Model: ScrewNet_2imgs")
    best_model = ScrewNet_2imgs(n_output=8)
    testset = RigidTransformDataset(args.ntest, args.test_dir)
elif args.model_type == 'noLSTM':
    print("Testing Model: ScrewNet_noLSTM")
    best_model = ScrewNet_NoLSTM(seq_len=16, fc_replace_lstm_dim=1000, n_output=8)
    testset = ArticulationDataset(args.ntest, args.test_dir)
else:
    print("Testing ScrewNet")
    best_model = ScrewNet(lstm_hidden_dim=1000, n_lstm_hidden_layers=1, n_output=8)
    testset = ArticulationDataset(args.ntest, args.test_dir)

ScrewNet_2imgs appears to be an ablated, cut-down version, so let's start reading from this model's source code:

class ScrewNet_2imgs(nn.Module):
    def __init__(self, n_output=8):
        super(ScrewNet_2imgs, self).__init__()

        self.fc_mlp_dim_1 = 2000
        self.fc_mlp_dim_2 = 512
        self.fc_mlp_dim_3 = 256
        self.n_output = n_output

        self.resnet = models.resnet18()
        self.bn_res_1 = nn.BatchNorm1d(1000, momentum=0.01)  # normalizes the 1000-d ResNet feature

        self.fc_mlp_1 = nn.Linear(self.fc_mlp_dim_1, self.fc_mlp_dim_1)
        self.bn_mlp_1 = nn.BatchNorm1d(self.fc_mlp_dim_1, momentum=0.01)
        self.fc_mlp_2 = nn.Linear(self.fc_mlp_dim_1, self.fc_mlp_dim_2)
        self.bn_mlp_2 = nn.BatchNorm1d(self.fc_mlp_dim_2, momentum=0.01)
        self.fc_mlp_3 = nn.Linear(self.fc_mlp_dim_2, self.fc_mlp_dim_3)
        self.bn_mlp_3 = nn.BatchNorm1d(self.fc_mlp_dim_3, momentum=0.01)
        self.fc_mlp_4 = nn.Linear(self.fc_mlp_dim_3, self.n_output)

    def forward(self, X_3d):
        # X shape: Batch x Sequence x 3 Channels x img_dims
        # Run resnet sequentially on the data to generate embedding sequence
        cnn_embed_seq = []
        for t in range(X_3d.size(1)):  # iterate over the frames
            x = self.resnet(X_3d[:, t, :, :, :])  # t-th frame through ResNet: B x 3 x W x H in, B x 1000 out
            x = x.view(x.size(0), -1)  # flatten to B x vector_size (one CNN embedding per sample)
            x = self.bn_res_1(x)  # batch-norm over the B x 1000 embeddings
            cnn_embed_seq.append(x)

        # cnn_embed_seq is now a list of length N whose elements have shape B x 1000.
        # Stack it into one tensor and swap the sample and time dimensions.
        cnn_embed_seq = torch.stack(cnn_embed_seq, dim=0).transpose_(0, 1)
        # cnn_embed_seq now has shape B x N x 1000.
        # After transpose the shape changes but the memory layout does not, so a direct
        # .view() would fail; call .contiguous() first and then .view().
        x_rnn = cnn_embed_seq.contiguous().view(-1, self.fc_mlp_dim_1)
        # The view target is B x 2000, so the input must have had N = 2:
        # the dataset paired with this model always provides 2-frame "videos".
        # FC layers
        x_rnn = self.fc_mlp_1(x_rnn)  # B x 2000 => B x 2000
        x_rnn = F.relu(x_rnn)
        x_rnn = self.bn_mlp_1(x_rnn)
        x_rnn = self.fc_mlp_2(x_rnn)  # B x 2000 => B x 512
        x_rnn = F.relu(x_rnn)
        x_rnn = self.bn_mlp_2(x_rnn)
        x_rnn = self.fc_mlp_3(x_rnn)  # B x 512 => B x 256
        x_rnn = F.relu(x_rnn)
        x_rnn = self.bn_mlp_3(x_rnn)
        x_rnn = self.fc_mlp_4(x_rnn)  # B x 256 => B x 8 (the 8 screw parameters)
        return x_rnn.view(X_3d.size(0), -1)  # return B x 8
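
Regarding the models.resnet18() used above, a quick shape check (a toy input of my own, assuming the stock torchvision model with its default 1000-way classification head):

import torch
from torchvision import models

resnet = models.resnet18()              # stock torchvision ResNet-18, default 1000-class head
x = torch.randn(4, 3, 224, 224)         # B x 3 x H x W
print(resnet(x).shape)                  # -> torch.Size([4, 1000])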

The models.resnet18() here is the off-the-shelf implementation from torchvision.models; for our purposes it can simply be treated as a module that maps a B x 3 x W x H input to a B x 1000 ResNet feature.

Next up is the full ScrewNet, the LSTM version:

class ScrewNet(nn.Module):
    def __init__(self, lstm_hidden_dim=1000, n_lstm_hidden_layers=1, drop_p=0.5, n_output=8):
        super(ScrewNet, self).__init__()

        self.fc_res_dim_1 = 512
        self.lstm_input_dim = 1000
        self.lstm_hidden_dim = lstm_hidden_dim
        self.n_lstm_hidden_layers = n_lstm_hidden_layers
        self.fc_lstm_dim_1 = 256
        self.fc_lstm_dim_2 = 128
        self.n_output = n_output
        self.drop_p = drop_p

        self.resnet = models.resnet18()
        self.fc_res_1 = nn.Linear(self.lstm_input_dim, self.fc_res_dim_1)
        self.bn_res_1 = nn.BatchNorm1d(self.fc_res_dim_1, momentum=0.01)
        self.fc_res_2 = nn.Linear(self.fc_res_dim_1, self.lstm_input_dim)

        self.LSTM = nn.LSTM(
            input_size=self.lstm_input_dim,
            hidden_size=self.lstm_hidden_dim,
            num_layers=self.n_lstm_hidden_layers,
            batch_first=True,
        )

        self.fc_lstm_1 = nn.Linear(self.lstm_hidden_dim, self.fc_lstm_dim_1)
        self.bn_lstm_1 = nn.BatchNorm1d(self.fc_lstm_dim_1, momentum=0.01)
        self.fc_lstm_2 = nn.Linear(self.fc_lstm_dim_1, self.fc_lstm_dim_2)
        self.bn_lstm_2 = nn.BatchNorm1d(self.fc_lstm_dim_2, momentum=0.01)
        self.dropout_layer1 = nn.Dropout(p=self.drop_p)
        self.fc_lstm_3 = nn.Linear(self.fc_lstm_dim_2, self.n_output)

    def forward(self, X_3d):
        # input shape: B x N x 3 x W x H
        cnn_embed_seq = []
        for t in range(X_3d.size(1)):  # iterate over the N frames
            x = self.resnet(X_3d[:, t, :, :, :])  # B x 1000
            x = x.view(x.size(0), -1)  # B x 1000
            x = self.bn_res_1(self.fc_res_1(x))  # B x 1000 => B x 512
            x = F.relu(x)
            x = self.fc_res_2(x)  # B x 512 => B x 1000
            cnn_embed_seq.append(x)

        cnn_embed_seq = torch.stack(cnn_embed_seq, dim=0).transpose_(0, 1)
        # cnn_embed_seq now has shape B x N x 1000

        # flatten_parameters() lays the LSTM weights out as one contiguous chunk in memory,
        # improving memory use and efficiency; similar in spirit to tensor.contiguous().
        self.LSTM.flatten_parameters()

        RNN_out, (h_n, h_c) = self.LSTM(cnn_embed_seq, None)
        # h_n: final hidden state, shape (n_layers, B, hidden_size), here (1, B, 1000)
        # h_c: final cell state, shape (n_layers, B, hidden_size), here (1, B, 1000)
        # RNN_out: B x N x 1000
        # None represents zero initial hidden state

        # FC layers
        x_rnn = RNN_out.contiguous().view(-1, self.lstm_hidden_dim)  # BN x 1000
        x_rnn = self.bn_lstm_1(self.fc_lstm_1(x_rnn))  # BN x 1000 => BN x 256
        x_rnn = F.relu(x_rnn)
        x_rnn = self.bn_lstm_2(self.fc_lstm_2(x_rnn))  # BN x 256 => BN x 128
        x_rnn = F.relu(x_rnn)
        x_rnn = self.fc_lstm_3(x_rnn)  # BN x 8
        return x_rnn.view(X_3d.size(0), -1)  # return B x 8N
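
As a quick standalone reference for the nn.LSTM shapes used in the forward pass above (toy tensors of my own, not repo code); the last two lines also show how a flat B x 8N output can be viewed back as B x N x 8, which is what the loss function later does:

import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=1000, hidden_size=1000, num_layers=1, batch_first=True)
seq = torch.randn(4, 16, 1000)                  # B x N x 1000 embedding sequence
out, (h_n, c_n) = lstm(seq, None)               # None -> zero initial hidden and cell states
print(out.shape, h_n.shape, c_n.shape)          # [4, 16, 1000], [1, 4, 1000], [1, 4, 1000]

flat = torch.randn(4, 16 * 8)                   # B x 8N, the layout returned by ScrewNet.forward
print(flat.view(4, -1, 8).shape)                # [4, 16, 8], the layout the loss function works with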

This is where the LSTM's input and output formats come in; the official PyTorch documentation for nn.LSTM describes the arguments in detail. Note the dimension changes in the final fully-connected layers: the network ends up predicting B x 8N, because every sample has N frames and an 8-d screw parameter is predicted for each of them. Whether the output is laid out as B x 8N or $B \times N \times 8$ is not really important; as long as the loss function reshapes it correctly before back-propagation, the regression works out either way.

The code also provides a no-LSTM version:

class ScrewNet_NoLSTM(nn.Module):
    def __init__(self, seq_len=16, fc_replace_lstm_dim=1000, n_output=8):
        super(ScrewNet_NoLSTM, self).__init__()
        self.fc_replace_lstm_seq_dim = fc_replace_lstm_dim * seq_len
        ...
        self.fc_replace_lstm = nn.Linear(self.fc_replace_lstm_seq_dim, self.fc_replace_lstm_seq_dim)

    def forward(self, X_3d):
        ...
        # FC replacing LSTM layer
        # cnn_embed_seq = (B x N x 1000)
        cnn_embed_seq = cnn_embed_seq.contiguous().view(cnn_embed_seq.size(0), -1)
        # (B x N x 1000) => (B x 1000N)
        x_rnn = F.relu(self.fc_replace_lstm(cnn_embed_seq))  # (B x 1000N) => (B x 1000N)
        x_rnn = x_rnn.view(-1, self.fc_replace_lstm_dim)  # (BN x 1000)
        # the remaining FC layers are identical to the ones above
        ...

loss.py

def articulation_lstm_loss_spatial_distance(pred, target, wt_on_ortho=1.):
    """ Based on Spatial distance. Please refer to the paper for more details.
    """
    pred = pred.view(pred.size(0), -1, 8)[:, 1:, :]  # We don't need the first row as it is for single image
    # (B, N - 1, 8)

    # Spatial distance loss: axis orientation error plus axis position error
    dist_err = orientation_difference_bw_plucker_lines(target, pred) ** 2 + \
               2. * distance_bw_plucker_lines(target, pred) ** 2

    # Configuration loss, i.e. the error in theta and d
    conf_err = theta_config_error(target, pred) ** 2 + d_config_error(target, pred) ** 2

    err = dist_err + conf_err
    loss = torch.mean(err)

    # Ensure l_hat has norm 1 (unit-vector constraint)
    loss += torch.mean((torch.norm(pred[:, :, :3], dim=-1) - 1.) ** 2)

    # Ensure orthogonality between l_hat and m (Plücker constraint)
    loss += wt_on_ortho * torch.mean(torch.abs(torch.sum(torch.mul(pred[:, :, :3], pred[:, :, 3:6]), dim=-1)))

    if torch.isnan(loss):
        print("target: Min: {}, Max{}".format(target.min(), target.max()))
        print("Prediction: Min: {}, Max{}".format(pred.min(), pred.max()))
        print("L2 error: {}".format(torch.mean((target - pred) ** 2)))
        print("Orientation loss:{}".format(torch.mean(orientation_difference_bw_plucker_lines(target, pred) ** 2)))
        print("Distance loss:{}".format(torch.mean(distance_bw_plucker_lines(target, pred) ** 2)))
        print("Configuration loss:{}".format(torch.mean(conf_err)))

    return loss


def distance_bw_plucker_lines(target, prediction, eps=1e-10):
    """ Input shapes Tensors: Batch X #Images X 8
    # Based on formula from Plücker Coordinates for Lines in the Space by Prof. Yan-bin Jia
    # Verified by https://keisan.casio.com/exec/system/1223531414
    """
    norm_cross_prod = torch.norm(torch.cross(target[:, :, :3], prediction[:, :, :3], dim=-1), dim=-1)
    dist = torch.zeros_like(norm_cross_prod)

    # Checking for Parallel Lines
    if torch.any(norm_cross_prod <= eps):
        zero_idxs = (norm_cross_prod <= eps).nonzero(as_tuple=True)
        scales = torch.norm(prediction[zero_idxs][:, :3], dim=-1) / torch.norm(target[zero_idxs][:, :3], dim=-1) + eps
        dist[zero_idxs] = torch.norm(torch.cross(target[zero_idxs][:, :3], (
                target[zero_idxs][:, 3:6] - prediction[zero_idxs][:, 3:6] / scales.unsqueeze(-1))), dim=-1) / (
                torch.mul(target[zero_idxs][:, :3], target[zero_idxs][:, :3]).sum(dim=-1) + eps)

    # Skew Lines: Non zero cross product
    nonzero_idxs = (norm_cross_prod > eps).nonzero(as_tuple=True)
    dist[nonzero_idxs] = torch.abs(
        torch.mul(target[nonzero_idxs][:, :3], prediction[nonzero_idxs][:, 3:6]).sum(dim=-1) + torch.mul(
            target[nonzero_idxs][:, 3:6], prediction[nonzero_idxs][:, :3]).sum(dim=-1)) / (
            norm_cross_prod[nonzero_idxs] + eps)
    return dist


def orientation_difference_bw_plucker_lines(target, prediction, eps=1e-6):
    """ Input shapes Tensors: (B, N, 8)
    range of arccos is [0, pi)"""
    return torch.acos(torch.clamp(torch.mul(target[:, :, :3], prediction[:, :, :3]).sum(dim=-1) /
                      (torch.norm(target[:, :, :3], dim=-1) * torch.norm(prediction[:, :, :3], dim=-1) + eps),
                      min=-1, max=1))


def theta_config_error(target, prediction):
    # loss on theta: distance between the two rotation matrices built from axis and angle
    rot_tar = angle_axis_to_rotation_matrix(target[:, :, :3], target[:, :, 6]).view(-1, 3, 3)
    rot_pred = angle_axis_to_rotation_matrix(prediction[:, :, :3], prediction[:, :, 6]).view(-1, 3, 3)
    I_ = torch.eye(3).reshape((1, 3, 3))
    I_ = I_.repeat(rot_tar.size(0), 1, 1).to(target.device)
    return torch.norm(I_ - torch.bmm(rot_pred, rot_tar.transpose(1, 2)), dim=(1, 2), p=2).view(target.shape[:2])


def d_config_error(target, prediction):
    # loss on d: compare the translations as vectors along their respective axes
    tar_d = target[:, :, 7].unsqueeze(-1)
    pred_d = prediction[:, :, 7].unsqueeze(-1)
    tar_d = target[:, :, :3] * tar_d
    pred_d = prediction[:, :, :3] * pred_d
    return (tar_d - pred_d).norm(dim=-1)

2.dataset.py

Let's start with the simpler two-image dataset, RigidTransformDataset. Note that a Dataset subclass only needs to override __len__ and __getitem__ (a minimal toy sketch of this contract follows, then the real code).

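A toy sketch of my own (made-up data, not repo code) showing the contract: __len__, __getitem__, and returning each sample as a dict, which is exactly the convention the real datasets below follow.

import torch
from torch.utils.data import Dataset, DataLoader

class ToyDepthDataset(Dataset):
    def __init__(self, length=32):
        self.length = length

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        depth = torch.randn(2, 3, 120, 160)      # fake "2 frames x 3 channels x H x W"
        label = torch.randn(8)                   # fake 8-d screw label <l, m, theta, d>
        return {'depth': depth, 'label': label}

loader = DataLoader(ToyDepthDataset(), batch_size=4, shuffle=True)
batch = next(iter(loader))
print(batch['depth'].shape, batch['label'].shape)  # -> [4, 2, 3, 120, 160] and [4, 8]
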
"""
Data loader class for the 2-imgs ablated version
"""
class RigidTransformDataset(Dataset):
    def __init__(self,
                 ntrain,
                 root_dir,
                 n_dof=1,
                 norm_factor=1.,
                 transform=None):
        super(RigidTransformDataset, self).__init__()

        self.root_dir = root_dir
        self.labels_data = None
        self.length = ntrain
        self.n_dof = n_dof
        self.normalization_factor = norm_factor
        self.transform = transform
        self.augmentation_factor = 15

    def __len__(self):
        return self.length

    def __getitem__(self, idx, imgs_per_object=16):
        if self.labels_data is None:
            self.labels_data = h5py.File(os.path.join(self.root_dir, 'complete_data.hdf5'), 'r')

        obj_idx = int(idx / self.augmentation_factor)
        obj_data_idx = idx % self.augmentation_factor + 1
        obj_data = self.labels_data['obj_' + str(obj_idx).zfill(6)]  # object index zero-padded to 6 digits

        # Load depth image
        depth_imgs = torch.tensor([obj_data['depth_imgs'][0],
                                   obj_data['depth_imgs'][obj_data_idx]])  # keep only the first and the chosen frame
        # shape: N x W x H
        depth_imgs.unsqueeze_(1).float()  # unsqueeze_ adds a channel dim in place: N x 1 x W x H (the .float() result is discarded)
        depth_imgs = torch.cat((depth_imgs, depth_imgs, depth_imgs), dim=1)
        # replicate the depth channel three times => N x 3 x W x H, so that the single-channel
        # depth map matches the 3-channel input expected by the stock ResNet

        # # Load labels
        pt1 = obj_data['moving_frame_in_world'][0, :]  # moving-frame pose (position + wxyz quaternion) in the world, frame 0
        pt2 = obj_data['moving_frame_in_world'][obj_data_idx, :]  # moving-frame pose in the world, chosen frame
        pt1_T_pt2 = change_frames(pt1, pt2)  # relative pose of pt2 expressed in pt1's frame

        # Object pose in world
        obj_pose_in_world = np.array(obj_data['embedding_and_params'])[-7:]  # obj_pose, obj_quat_wxyz
        obj_T_pt1 = change_frames(obj_pose_in_world, pt1)  # the conversion into the base object frame mentioned in the paper

        # Build the label from the screw parameters: label := <l_hat, m, theta, d> = <3, 3, 1, 1>
        l_hat, m, theta, d = transform_to_screw(translation=pt1_T_pt2[:3],
                                                quat_in_wxyz=pt1_T_pt2[3:])

        # Convert line in object_local_coordinates
        new_l = transform_plucker_line(np.concatenate((l_hat, m)), trans=obj_T_pt1[:3], quat=obj_T_pt1[3:])
        label = np.concatenate((new_l, [theta], [d]))  # This defines frames wrt pt 1

        # Normalize labels
        label[3:6] /= self.normalization_factor
        # Scaling m appropriately

        label = torch.from_numpy(label).float()
        sample = {'depth': depth_imgs,
                  'label': label}  # the ground-truth sample is returned as a dict

        return sample

The multi-image dataset is largely the same, but there is one detail worth digging into.

class ArticulationDataset(Dataset):
    def __getitem__(self, idx):
        ...
        pt1 = moving_body_poses[0, :]  # Fixed common reference frame
        for i in range(len(moving_body_poses) - 1):
            pt2 = moving_body_poses[i + 1, :]
            pt1_T_pt2 = change_frames(pt1, pt2)

            # Generating labels in screw notation: label := <l_hat, m, theta, d> = <3, 3, 1, 1>
            l_hat, m, theta, d = transform_to_screw(translation=pt1_T_pt2[:3],
                                                    quat_in_wxyz=pt1_T_pt2[3:])

            # Convert line in object_local_coordinates
            new_l = transform_plucker_line(np.concatenate((l_hat, m)), trans=obj_T_pt1[:3], quat=obj_T_pt1[3:])
            label[i, :] = np.concatenate((new_l, [theta], [d]))  # This defines frames wrt pt 1


def transform_to_screw(translation, quat_in_wxyz, tol=1e-6):
    dq = dq3d.dualquat(dq3d.quat(quat_as_xyzw(quat_in_wxyz)), translation)
    screw = dual_quaternion_to_screw(dq, tol)
    return screw


def dual_quaternion_to_screw(dq, tol=1e-6):
    l_hat, theta = tf3d.quaternions.quat2axangle(np.array([dq.real.w, dq.real.x, dq.real.y, dq.real.z]))

    if theta < tol or abs(theta - np.pi) < tol:
        t_vec = dq.translation()
        l_hat = t_vec / (np.linalg.norm(t_vec) + 1e-10)
        theta = tol  # This makes sure that tan(theta) is defined
    else:
        t_vec = (2 * tf3d.quaternions.qmult(dq.dual.data, tf3d.quaternions.qconjugate(dq.real.data)))[
                1:]  # taking xyz from wxyz

    d = t_vec.dot(l_hat)
    m = (1 / 2) * (np.cross(t_vec, l_hat) + ((t_vec - d * l_hat) / np.tan(theta / 2)))
    return l_hat, m, theta, d

Only by understanding what transform_to_screw is doing do we really get to the substance of screw theory; the main reference is the paper linked in the original post. The overall logic is that a quaternion can represent a 3D rotation, while a dual quaternion can represent a 3D rotation and a translation at the same time, so recovering the screw parameters from a dual quaternion is the natural thing to do. The derivation goes as follows.

Any rigid-body motion in space can be decomposed into a translation of some point of the body plus a rotation about an axis passing through that point. Take this point to be the origin of the body-fixed frame, write the motion as $R$ and $t$, let $p$ be the quaternion corresponding to the rotation matrix $R$, and build the dual quaternion $q$ from $R$ and $t$. By Chasles' theorem (the foundation of screw theory), any rigid-body motion can also be viewed as a finite screw motion, i.e. a rotation about some axis plus a translation along that axis, with parameters $(\textbf{l},\textbf{m},\theta,d)$.

First, writing the quaternion $p$ in axis-angle form, $p=(\cos\frac{\theta}{2},\ \textbf{l}\sin\frac{\theta}{2})$, directly yields $\textbf{l}$ and $\theta$, with exactly the same physical meaning as the screw parameters. The remaining parameters satisfy the following, where $q_d$ is the dual part of $q$, $p^*$ the conjugate of $p$, and $(\cdot)_{xyz}$ extracts the vector part of a quaternion:
$$
d=\textbf{t}\cdot\textbf{l},\qquad \textbf{t}=\left(2\,q_d\,p^*\right)_{xyz} \\
\textbf{m}=\frac{1}{2}\left(\textbf{t}\times\textbf{l}+(\textbf{t}-d\,\textbf{l})\cot\frac{\theta}{2}\right)
$$
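
A minimal numerical check of these formulas (my own sketch, independent of the repo's dual-quaternion code): pick a rotation about a unit axis $\textbf{l}$ by $\theta$ and a translation $\textbf{t}$, compute $d$ and $\textbf{m}$ as above, and verify that the screw parameters reproduce $\textbf{t}$ via $\textbf{t}=(I-R)(\textbf{l}\times\textbf{m})+d\,\textbf{l}$.

import numpy as np
from scipy.spatial.transform import Rotation

l = np.array([1.0, 2.0, 2.0]) / 3.0            # unit rotation axis
theta = 0.7
t = np.array([0.3, -0.5, 1.2])                 # arbitrary translation
R = Rotation.from_rotvec(theta * l).as_matrix()

d = t.dot(l)                                                    # translation along the axis
m = 0.5 * (np.cross(t, l) + (t - d * l) / np.tan(theta / 2))    # moment of the screw axis

# The screw axis passes through the point l x m, so the motion it encodes is
# x' = R (x - l x m) + l x m + d * l, which means t = (I - R)(l x m) + d * l:
t_rec = (np.eye(3) - R) @ np.cross(l, m) + d * l
print(np.allclose(t_rec, t))                   # -> True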

3.train_model.py

This script handles the setup of the three different models; it is nicely encapsulated and the overall logic is very easy to follow.

trainset = ...
testset = ...
loss_fn = ...
network = ...
testloader = torch.utils.data.DataLoader(testset, batch_size=args.batch,
                                         shuffle=True, num_workers=args.nwork,
                                         pin_memory=True)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch,
                                          shuffle=True, num_workers=args.nwork,
                                          pin_memory=True)

# Load saved weights
if args.load_wts:
    network.load_state_dict(torch.load(args.wts_dir + args.prior_wts + '.net'))

# Set up the trainer
if torch.cuda.is_available():
    device = torch.device(args.device)
else:
    device = torch.device('cpu')

optimizer = torch.optim.Adam(network.parameters(),
                             lr=args.learning_rate,
                             weight_decay=1e-2)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=lr_schedule, gamma=lr_gamma)
trainer = ModelTrainer(model=network,
                       train_loader=trainloader,
                       test_loader=testloader,
                       optimizer=optimizer,
                       scheduler=scheduler,
                       criterion=loss_fn,
                       epochs=args.epochs,
                       name=args.name,
                       test_freq=args.val_freq,
                       device=args.device)

# Train
best_model = trainer.train()

4.model_trainer.py

There is not much to say here: it is the usual training loop (compute the loss, back-propagate, plot, write logs, save the model) plus evaluation (compute means and standard deviations).


class ModelTrainer(object):
    def __init__(self, *kwargs):
        pass

    def train(self):
        best_tloss = 1e8
        for epoch in range(self.epochs + 1):
            sys.stdout.flush()
            loss = self.train_epoch(epoch)
            self.losses.append(loss)
            self.writer.add_scalar('Loss/train', loss, epoch)

            if epoch % self.test_freq == 0:
                tloss = self.test_epoch(epoch)
                self.tlosses.append(tloss)
                self.plot_losses()
                self.writer.add_scalar('Loss/validation', tloss, epoch)

                if tloss < best_tloss:
                    print('saving model.')
                    net_fname = os.path.join(self.wts_dir, str(self.name) + '.net')
                    torch.save(self.model.state_dict(), net_fname)  # save the better-performing model to disk
                    best_tloss = tloss

            self.scheduler.step()

            # Visualize gradients
            total_norm = 0.
            nan_count = 0
            for tag, parm in self.model.named_parameters():
                if torch.isnan(parm.grad).any():
                    print("Encountered NaNs in gradients at {} layer".format(tag))
                    nan_count += 1
                else:
                    self.writer.add_histogram(tag, parm.grad.data.cpu().numpy(), epoch)
                    param_norm = parm.grad.data.norm(2)
                    total_norm += param_norm.item() ** 2

            total_norm = total_norm ** (1. / 2)
            self.writer.add_scalar('Gradient/2-norm', total_norm, epoch)
            if nan_count > 0:
                raise ValueError("Encountered NaNs in gradients")

        # plot losses one more time
        self.plot_losses()
        # re-load the best state dictionary that was saved earlier.
        self.model.load_state_dict(torch.load(net_fname, map_location='cpu'))

        # export scalar data to JSON for external processing
        self.writer.export_scalars_to_json("./all_scalars.json")
        self.writer.close()
        return self.model

    def train_epoch(self, epoch):
        start = time.time()
        running_loss = 0
        batches_per_dataset = len(self.trainloader.dataset) / self.trainloader.batch_size
        self.model.train()  # Put model in training mode

        for i, X in enumerate(self.trainloader):
            self.optimizer.zero_grad()
            depth, labels = X['depth'].to(self.device), \
                            X['label'].to(self.device)

            y_pred = self.model(depth)
            loss = self.criterion(y_pred, labels)
            if loss.data == -float('inf'):
                print('inf loss caught, not backpropping')
                running_loss += -1000
            else:
                loss.backward()

                # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 10.)

                self.optimizer.step()
                running_loss += loss.item()

        stop = time.time()
        print('Epoch %s - Train Loss: %.5f Time: %.5f' % (str(epoch).zfill(3),
                                                          running_loss / batches_per_dataset,
                                                          stop - start))
        return running_loss / batches_per_dataset

    def test_epoch(self, epoch):
        start = time.time()
        running_loss = 0
        batches_per_dataset = len(self.testloader.dataset) / self.testloader.batch_size
        self.model.eval()  # Put batch norm layers in eval mode

        with torch.no_grad():
            for i, X in enumerate(self.testloader):
                depth, labels = X['depth'].to(self.device), \
                                X['label'].to(self.device)
                y_pred = self.model(depth)
                loss = self.criterion(y_pred, labels)
                running_loss += loss.item()

        stop = time.time()
        print('Epoch %s - Test Loss: %.5f Euc. Time: %.5f' % (str(epoch).zfill(3),
                                                              running_loss / batches_per_dataset,
                                                              stop - start))
        return running_loss / batches_per_dataset

    def test_best_model(self, best_model, fname_suffix='', dual_quat_mode=False):
        best_model.eval()  # Put model in evaluation mode
        ...

        with torch.no_grad():
            for X in self.testloader:
                depth, all_labels, labels = X['depth'].to(self.device), \
                                            X['all_labels'].to(self.device), \
                                            X['label'].to(self.device)
                y_pred = best_model(depth, all_labels)
                y_pred = y_pred.view(y_pred.size(0), -1, 8)

                if dual_quat_mode:
                    y_pred = dual_quaternion_to_screw_batch_mode(y_pred)
                    labels = dual_quaternion_to_screw_batch_mode(labels)

                err = labels - y_pred
                all_l_hat_err = torch.cat(
                    (all_l_hat_err, torch.mean(torch.norm(err[:, :, :3], dim=-1), dim=-1).cpu()))
                all_m_err = torch.cat((all_m_err, torch.mean(torch.norm(err[:, :, 3:6], dim=-1), dim=-1).cpu()))
                all_q_err = torch.cat((all_q_err, torch.mean(err[:, :, 6], dim=-1).cpu()))
                all_d_err = torch.cat((all_d_err, torch.mean(err[:, :, 7], dim=-1).cpu()))

                all_l_hat_std = torch.cat(
                    (all_l_hat_std, torch.std(torch.norm(err[:, :, :3], dim=-1), dim=-1).cpu()))
                all_m_std = torch.cat((all_m_std, torch.std(torch.norm(err[:, :, 3:6], dim=-1), dim=-1).cpu()))
                all_q_std = torch.cat((all_q_std, torch.std(err[:, :, 6], dim=-1).cpu()))
                all_d_std = torch.cat((all_d_std, torch.std(err[:, :, 7], dim=-1).cpu()))

        # Plot variation of screw axis
        pass

    def plot_grad_flow(self, named_parameters):
        pass

    def plot_losses(self):
        pass

5.Other Modules

The remaining helper modules are not worth more space here. Note that there is also a noisy_models.py, which borrows from the repository linked in the original post and is presumably the data-augmentation (noise) mechanism used in the corresponding paper; it is not covered here.
