torch.chunk(tensor, chunk_num, dim)与torch.cat()原理相反,它是将tensor按dim(行或列)分割成chunk_num个tensor块,返回的是一个元组。 a torch.Tensor([[4,5,7], [3,9,8], [9,6,7]])
b torch.chunk(a, 3, dim 1)
print(a)
pri…
pytorch.chunk是挨着挨着进行分块的,其实这个可以借鉴Focus的思想,隔一个分一个块
代码:
import numpy as np
import torch
data torch.from_numpy(np.random.rand(1, 6, 3, 5))
print(str(data))
for i, data_i in enumerate(data.chunk(…