From c09986e580652f2e9d990602298fbf4571126ea6 Mon Sep 17 00:00:00 2001 From: James Cross Date: Tue, 3 Dec 2019 15:43:50 -0800 Subject: [PATCH] test scriptify MultiheadAttention (#670) Summary: Pull Request resolved: https://github.com/pytorch/translate/pull/670 We introduce a new test suite where we will iteratively ensure that various model components are TorchScript compliant (JIT-able). Differential Revision: D18799715 fbshipit-source-id: b4a1486080f6791d1fcfeaa51edc4a38c4fadc20 --- pytorch_translate/test/test_export_models.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) create mode 100644 pytorch_translate/test/test_export_models.py diff --git a/pytorch_translate/test/test_export_models.py b/pytorch_translate/test/test_export_models.py new file mode 100644 index 00000000..eb05fd7b --- /dev/null +++ b/pytorch_translate/test/test_export_models.py @@ -0,0 +1,13 @@ +#!/usr/bin/env python3 + +import unittest + +import torch +from fairseq.modules import multihead_attention + + +class TestExportModels(unittest.TestCase): + @unittest.skip("TDD: placeholder for development") + def test_export_multihead_attention(self): + module = multihead_attention.MultiheadAttention(embed_dim=8, num_heads=2) + torch.jit.script(module)