class WordEmbeddingWrapper(torch.nn.Module):
def __init__( self, original_word_embedding):
super(WordEmbeddingWrapper, self).__init__()
def forward( self, input_ids ):
# TODO tensor of shape (1,n) to List
return hidden_states
# 创建WordEmbeddingWrapper实例
word_embedding_wrapper = WordEmbeddingWrapper(model.w['emb.weight'])
# 使用torch.jit.trace对模型进行跟踪编译
traced_model = torch.jit.trace(word_embedding_wrapper, example_input)
我调试了下,在WordEmbeddingWrapper(model.w['emb.weight'])语句中(也就是对象初始化的时候),并没有执行forward方法。但是在执行torch.jit.trace 的时候,先执行了forward方法。这是什么神奇的写法?trace的定义:
def trace(
func,
example_inputs=None,
optimize=None,
check_trace=True,
check_inputs=None,
check_tolerance=1e-5,
strict=True,
_force_outplace=False,
_module_class=None,
_compilation_unit=_python_cu,
example_kwarg_inputs=None,
_store_inputs=True,
):
在torch.jit.trace(word_embedding_wrapper, example_input)中trace调用的时候,是默认前两个?期望大佬解惑。
--
修改:feng321 FROM 120.242.238.*
FROM 120.242.238.*