import torch


class test_cat(torch.nn.Module):
    """Concatenate slice 1 of the input's third axis with a fixed tensor of ones.

    The module holds a constant ``(1, 2, 4)`` tensor of ones and appends it,
    along the last axis, to ``x[:, :, 1, :]``.
    """

    def __init__(self) -> None:
        super().__init__()
        # Plain tensor attribute, deliberately not a registered buffer or
        # Parameter. It is excluded from state_dict and is not moved by
        # .to(device)/.cuda(). torch.jit.trace embeds it in the traced graph
        # as a constant rather than as a module attribute.
        self.ones = torch.ones(1, 2, 4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``torch.cat([x[:, :, 1, :], self.ones], dim=-1)``.

        Args:
            x: 4-D tensor. torch.cat requires the slice ``x[:, :, 1, :]`` to
                match ``self.ones`` in rank and in every non-concatenated
                axis. So ``x.shape[:2]`` must be ``(1, 2)``, and ``x.shape[2]``
                must be at least 2 so that index 1 exists.

        Returns:
            Tensor of shape ``(1, 2, x.shape[3] + 4)``. The first
            ``x.shape[3]`` entries along the last axis come from
            ``x[:, :, 1, :]``; the final 4 are ones.
        """
        y = torch.cat([x[:, :, 1, :], self.ones], dim=-1)
        return y


def main(path: str = "test_cat.pt") -> None:
    """Trace ``test_cat`` on a zero example input and save it to ``path``.

    Args:
        path: Destination file for the TorchScript archive. Defaults to the
            filename the original script wrote.
    """
    # Named `example` rather than `input` to avoid shadowing the builtin.
    example = torch.zeros(1, 2, 3, 4)
    model = test_cat()
    traced = torch.jit.trace(model, example)
    traced.save(path)


# Guarded so that importing this module does not write a file as a side effect.
if __name__ == "__main__":
    main()