import numpy as np
import torch
import torch.nn as nn


class ScaledDotProductAttention(nn.Module):
    """ Scaled Dot-Product Attention """
    def __init__(self, scale):
        super().__init__()
        self.scale = scale
        self.softmax = nn.Softmax(dim=2)  # softmax over dim 2, i.e. across the keys for each query

    def forward(self, q, k, v, mask=None):
        u = torch.bmm(q, k.transpose(1, 2))   # 1. Matmul: (batch, n_q, n_k)
        u = u / self.scale                    # 2. Scale by sqrt(d_k)
        if mask is not None:
            u = u.masked_fill(mask, -np.inf)  # 3. Mask out forbidden positions
        attn = self.softmax(u)                # 4. Softmax over the keys
        output = torch.bmm(attn, v)           # 5. Weighted sum of the values
        return attn, output
if __name__ == "__main__":
    n_q, n_k, n_v = 2, 4, 4
    d_q, d_k, d_v = 128, 128, 64
    batch = 5

    # torch.randn creates a random tensor of the given shape; here (batch, n_q, d_q):
    # batch size, number of queries, query dimension. k and v are built the same way.
    q = torch.randn(batch, n_q, d_q)
    k = torch.randn(batch, n_k, d_k)
    v = torch.randn(batch, n_v, d_v)
    # .bool() casts the tensor to boolean; an all-False mask leaves every position unmasked.
    mask = torch.zeros(batch, n_q, n_k).bool()

    attention = ScaledDotProductAttention(scale=np.power(d_k, 0.5))
    attn, output = attention(q, k, v, mask=mask)

    # print(attn)
    # print(output)
    print(attn.shape)    # torch.Size([5, 2, 4])
    print(output.shape)  # torch.Size([5, 2, 64])
After attention = ScaledDotProductAttention(scale=np.power(d_k, 0.5)) runs, an object named attention of type ScaledDotProductAttention has been constructed. Then, after attn, output = attention(q, k, v, mask=mask) runs, the forward method has been called and the result of the forward pass comes back. I understand inheritance, object-oriented programming, and aspect-oriented programming in Python, but how exactly does the forward method get called?
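For context on the question: in Python, writing attention(q, k, v, mask=mask) invokes the object's __call__ method, and nn.Module defines __call__ so that, after running any registered hooks, it dispatches to self.forward. The real logic lives inside torch.nn.Module; the snippet below is only a minimal sketch of the idea (MiniModule and Doubler are made-up illustration classes, not PyTorch code).

class MiniModule:
    """Minimal sketch of how nn.Module turns instance calls into forward calls."""
    def __call__(self, *args, **kwargs):
        # (the real nn.Module also runs forward/backward hooks around this call)
        return self.forward(*args, **kwargs)

    def forward(self, *args, **kwargs):
        raise NotImplementedError


class Doubler(MiniModule):
    def forward(self, x):
        return 2 * x


m = Doubler()
print(m(3))  # 6 -- m(3) goes through MiniModule.__call__, which calls Doubler.forward

So attention(q, k, v, mask=mask) is really a call to nn.Module's __call__, which in turn calls the forward you overrode in ScaledDotProductAttention.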