import torch
from torch import nn
import torchviz
class MyModel(nn.Module):
    """Minimal model wrapping a single multi-head self-attention layer."""

    def __init__(self):
        super().__init__()
        # Multi-head attention: embedding dim 128, 4 attention heads;
        # batch_first=True means inputs/outputs are (batch, seq, embed_dim).
        self.mul = nn.MultiheadAttention(128, 4, batch_first=True)

    def forward(self, x):
        # Self-attention: query, key and value are all x.
        # MultiheadAttention returns (attn_output, attn_weights); keep only
        # the attended output.
        y, _ = self.mul(x, x, x)
        return y
# Random input: batch size 1, sequence length 720 (= 240 * 3), embedding
# dim 128. torch.randn samples each element from a standard normal
# distribution (mean 0, std 1).
x = torch.randn(1, 240 * 3, 128)
# torch.save(x, 'x.txt')
print(x)
print(x.shape)

my = MyModel()

# Alternative visualization with torchviz (kept for reference):
# dot = torchviz.make_dot(my(x), params=dict(my.named_parameters()))
# dot.format = 'svg'
# dot.render(filename='model_graph')

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter("./tensorboard_otameshi")
# nn.MultiheadAttention's traced graph can differ between the two trace
# invocations that the JIT sanity check performs (its fast path emits
# data-dependent constants), which raises
# torch.jit._trace.TracingCheckError. Disabling the strict trace check
# lets add_graph export the graph; the logged graph itself is still valid.
writer.add_graph(my, x, use_strict_trace=False)
writer.close()
这个网络,用注释中的 torchviz 代码可以输出计算图,可是用 add_graph 就报错。有大佬指教一下该怎么调试吗?代码很短,有空的话可以跑跑看。
出错行:writer.add_graph(my, x)
raise TracingCheckError(*diag_info)
torch.jit._trace.TracingCheckError: Tracing failed sanity checks!
ERROR: Graphs differed across invocations!
Graph diff:
graph(%self.1 : __torch__.MyModel,
%x : Tensor):
%mul : __torch__.torch.nn.modules.activation.MultiheadAttention = prim::GetAttr[name="mul"](%self.1)
+ %4 : bool = prim::Constant[value=1](), scope: __module.mul # D:\myProgram\ideaJava\yiZhiXiangMuZu\RWKV\BlinkDL_ChatRWKV\ChatRWKV\venv\Lib\site-packages\torch\nn\modules\activation.py:1196:0
- %4 : NoneType = prim::Constant(), scope: __module.mul
? ^
+ %5 : NoneType = prim::Constant(), scope: __module.mul
。。。。。。。。。。。。。。。。。
+ return (%14)
? ^
First diverging operator:
Node diff:
- %mul : __torch__.torch.nn.modules.activation.MultiheadAttention = prim::GetAttr[name="mul"](%self.1)
+ %mul : __torch__.torch.nn.modules.activation.___torch_mangle_1.MultiheadAttention = prim::GetAttr[name="mul"](%self.1)
? ++++++++++++++++++
谢谢
--
修改:feng321 FROM 114.99.170.*
FROM 114.99.170.*
![单击此查看原图](//static.mysmth.net/nForum/att/Python/168769/2731/middle)