diff --git a/aligner/tests/test_cli.py b/aligner/tests/test_cli.py index e784224..d2c69cd 100644 --- a/aligner/tests/test_cli.py +++ b/aligner/tests/test_cli.py @@ -5,14 +5,15 @@ import io import os +import re import subprocess import tempfile from contextlib import redirect_stderr from pathlib import Path -from unittest import SkipTest, TestCase from urllib.error import URLError from urllib.request import Request, urlopen +import pytest from typer.testing import CliRunner from ..classes import Segment @@ -21,41 +22,44 @@ VERBOSE_OVERRIDE = bool(os.environ.get("EVERYVOICE_VERBOSE_TESTS", False)) -class CLITest(TestCase): - def setUp(self) -> None: - self.runner = CliRunner() +@pytest.fixture(scope="class") +def runner(request) -> None: + request.cls.runner = CliRunner() - def test_main_help(self): + +@pytest.mark.usefixtures("runner") +class TestCLI: + def test_main_help(self, subtests): for help in "-h", "--help": - with self.subTest(help=help): + with subtests.test(help=help): result = self.runner.invoke(app, [help]) - self.assertEqual(result.exit_code, 0) - self.assertIn("align", result.stdout) - self.assertIn("extract", result.stdout) + assert result.exit_code == 0 + assert "align" in result.stdout + assert "extract" in result.stdout - def test_sub_help(self): + def test_sub_help(self, subtests): for cmd in "align", "extract": for help in "-h", "--help": - with self.subTest(cmd=cmd, help=help): + with subtests.test(cmd=cmd, help=help): result = self.runner.invoke(app, [cmd, help]) - self.assertEqual(result.exit_code, 0) - self.assertIn("Usage:", result.stdout) - self.assertIn(cmd, result.stdout) + assert result.exit_code == 0 + assert "Usage:" in result.stdout + assert cmd in result.stdout - def test_align_empty_file(self): - with self.subTest("empty file"): + def test_align_empty_file(self, subtests): + with subtests.test("empty file"): result = self.runner.invoke(app, ["align", os.devnull, os.devnull]) - self.assertNotEqual(result.exit_code, 0) - self.assertRegex(result.output, r"(?s)is.*empty") + assert result.exit_code != 0 + assert re.search(r"(?s)is.*empty", result.output) - with self.subTest("file with only empty lines"): + with subtests.test("file with only empty lines"): with tempfile.TemporaryDirectory() as tmpdir: textfile = os.path.join(tmpdir, "emptylines.txt") with open(textfile, "w", encoding="utf8") as f: f.write("\n \n \n") result = self.runner.invoke(app, ["align", textfile, os.devnull]) - self.assertNotEqual(result.exit_code, 0) - self.assertRegex(result.output, r"(?s)is.*empty") + assert result.exit_code != 0 + assert re.search(r"(?s)is.*empty", result.output) def fetch_ras_test_file(self, filename, outputdir): repo, path = "https://github.com/ReadAlongs/Studio/", "/tests/data/" @@ -65,7 +69,7 @@ def fetch_ras_test_file(self, filename, outputdir): with open(os.path.join(outputdir, filename), "wb") as f: f.write(response.read()) - def test_align_something(self): + def test_align_something(self, subtests): with tempfile.TemporaryDirectory() as tmpdir: tmppath = Path(tmpdir) try: @@ -73,7 +77,7 @@ def test_align_something(self): self.fetch_ras_test_file("ej-fra.txt", tmpdir) self.fetch_ras_test_file("ej-fra.m4a", tmpdir) except URLError as e: # pragma: no cover - raise SkipTest( + raise pytest.skip( f"Can't fetch test data: {e}; skipping the test that depends on the Internet." ) txt = tmppath / "ej-fra.txt" @@ -90,31 +94,31 @@ def test_align_something(self): textgrid = tmppath / "ej-fra-16000.TextGrid" wav_out = tmppath / "ej-fra-16000.wav" - with self.subTest("ctc-segmenter align"): + with subtests.test("ctc-segmenter align"): result = self.runner.invoke(app, ["align", str(txt), str(wav)]) if result.exit_code != 0: os.system("ls -la " + tmpdir) print(result.output) - self.assertEqual(result.exit_code, 0) - self.assertTrue(textgrid.exists()) - self.assertTrue(wav_out.exists()) + assert result.exit_code == 0 + assert textgrid.exists() + assert wav_out.exists() - with self.subTest("ctc-segmenter extract"): + with subtests.test("ctc-segmenter extract"): result = self.runner.invoke( app, ["extract", str(textgrid), str(wav_out), str(tmppath / "out")] ) if result.exit_code != 0: print(result.output) - self.assertEqual(result.exit_code, 0) - self.assertTrue((tmppath / "out/metadata.psv").exists()) + assert result.exit_code == 0 + assert (tmppath / "out/metadata.psv").exists() with open(txt, encoding="utf8") as txt_f: non_blank_line_count = sum(1 for line in txt_f if line.strip()) for i in range(non_blank_line_count): - self.assertTrue((tmppath / f"out/wavs/segment{i}.wav")) + assert (tmppath / f"out/wavs/segment{i}.wav").exists() -class MiscTests(TestCase): +class TestMisc: def test_segment(self): segment = Segment("text", 500, 700, 0.42) - self.assertEqual(len(segment), 200) - self.assertEqual(repr(segment), "text (0.42): [ 500, 700)") + assert len(segment) == 200 + assert repr(segment) == "text (0.42): [ 500, 700)" diff --git a/pyproject.toml b/pyproject.toml index a1094f1..825e711 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,13 +45,14 @@ include = ["/aligner"] [project.optional-dependencies] dev = [ "black~=24.3", + "coverage", "flake8>=4.0.1", "gitlint-core>=0.19.0", "isort>=5.10.1", "mypy>=1.8.0", "pre-commit>=3.2.0", "pytest", - "coverage", + "pytest-subtests", ] [project.urls]