这是一篇文献上附上的代码,文献的目的是根据深度学习探究材料的原子性质与其他性质的关系。
我截取这一部分的目的应该是想要把原子特征和晶体特征联系起来,但是我并不能理解这一部分的代码,谢谢大家
https://github.com/txie-93/cgcnn/blob/master/cgcnn/model.py完整代码在这里,目标部分在第168行到最后
def pooling(self, atom_fea, crystal_atom_idx):
"""
Pooling the atom features to crystal features
N: Total number of atoms in the batch
N0: Total number of crystals in the batch
Parameters
----------
atom_fea: Variable(torch.Tensor) shape (N, atom_fea_len)
Atom feature vectors of the batch
crystal_atom_idx: list of torch.LongTensor of length N0
Mapping from the crystal idx to atom idx
"""
assert sum([len(idx_map) for idx_map in crystal_atom_idx]) ==\
atom_fea.data.shape[0]
summed_fea = [torch.mean(atom_fea[idx_map], dim=0, keepdim=True)
for idx_map in crystal_atom_idx]
return torch.cat(summed_fea, dim=0)
assert 是一个断言判断,主要代码是变量summed_fea ,它等同下面代码
summed_fea=[]
for idx_map in crystal_atom_idx:
r = torch.mean(atom_fea[idx_map], dim=0, keepdim=True)
summed_fea.append(r)
这段代码要分清楚 torch.mean、torch.cat以及各个参数是什么就行了
如果对你有帮助,可以点击我这个回答右上方的【采纳】按钮,给我个采纳吗,谢谢