import torch
class test_cat(torch.nn.Module):
    """Concatenate one slice of a 4-D input with a constant block of ones.

    ``forward`` selects index 1 along dimension 2 of the input and appends a
    fixed ``(1, 2, 4)`` tensor of ones along the last dimension.
    """

    def __init__(self) -> None:
        super().__init__()
        # Register the constant as a buffer instead of a plain attribute so
        # that it follows the module through ``.to()`` / ``.cuda()`` /
        # ``.double()``.  ``persistent=False`` keeps it out of
        # ``state_dict()``, so checkpoints look exactly as they did when this
        # was an ordinary attribute.
        self.register_buffer("ones", torch.ones(1, 2, 4), persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x[:, :, 1, :]`` concatenated with the ones buffer.

        Args:
            x: A 4-D tensor whose first two dimensions are ``(1, 2)`` and
                whose third dimension has at least two entries, so that the
                selected slice matches the ``(1, 2, 4)`` buffer in every
                dimension except the last.

        Returns:
            A tensor of shape ``(1, 2, x.shape[-1] + 4)``: the selected slice
            followed by four ones along the last dimension.
        """
        selected = x[:, :, 1, :]
        return torch.cat([selected, self.ones], dim=-1)
# Build the module and a zero-filled example input of shape (1, 2, 3, 4),
# record a TorchScript trace of the forward pass, and write it to disk.
# NOTE(review): `input` shadows the builtin of the same name; the name is
# kept so that anything importing this module still finds it.
model = test_cat()
input = torch.zeros((1, 2, 3, 4))
trace = torch.jit.trace(model, (input,))
torch.jit.save(trace, "test_cat.pt")