1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
| x = torch.randn(1, 3, 4)
# 展平
x_flat = x.view(-1) # 展平为一维张量,-1表示自动计算这个维度的大小
# 改变形状
y = x.view(2, 12) # 2 x 12
z = x.reshape(3, 8) # 3 x 8(更通用)
# 插入维度
x_unsq = x.unsqueeze(0) # 在 dim=0 插一维: (1, 2, 3, 4)
x_unsq = x.unsqueeze(2) # (2, 3, 1, 4)
# 去掉大小为1的维度
x_sq = x_unsq.squeeze() # 去掉所有 =1 的维度
x_sq = x_unsq.squeeze(0) # 只去掉 dim=0
# 交换维度
x_perm = x.permute(0, 2, 1) # (2, 4, 3)
# 扩展维度
x_exp = x.expand(2, -1, -1) # 维度变为 (2, 3, 4),-1表示该维度不变
# 转置(2D 专用)
m = torch.randn(3, 4)
m_t = m.t() # (4, 3)
|