
allow test_transcribe to run on CPU when CUDA is not available

Jong Wook Kim 2 years ago
parent
commit b1d213c0c7
2 changed files with 5 additions and 3 deletions
  1. .github/workflows/test.yml (+1 −1)
  2. tests/test_transcribe.py (+4 −2)

+ 1 - 1
.github/workflows/test.yml

@@ -18,7 +18,7 @@ jobs:
            pytorch-version: 1.10.2
     steps:
       - uses: conda-incubator/setup-miniconda@v2
-      - run: conda install -n test python=${{ matrix.python-version }} pytorch=${{ matrix.pytorch-version }} cpuonly -c pytorch
+      - run: conda install -n test ffmpeg python=${{ matrix.python-version }} pytorch=${{ matrix.pytorch-version }} cpuonly -c pytorch
       - uses: actions/checkout@v2
       - run: echo "$CONDA/envs/test/bin" >> $GITHUB_PATH
       - run: pip install pytest
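
The ffmpeg package is added to the conda environment because whisper decodes audio by invoking the ffmpeg command-line tool; without it, test_transcribe cannot load tests/jfk.flac. A minimal sketch, not part of this commit, of how a test module could skip cleanly when the binary is absent (the skip message is illustrative):

    # Sketch only: skip the whole module if the ffmpeg binary is missing,
    # since whisper's audio loading shells out to it.
    import shutil

    import pytest

    if shutil.which("ffmpeg") is None:
        pytest.skip("ffmpeg not found on PATH", allow_module_level=True)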

+ 4 - 2
tests/test_transcribe.py

@@ -1,13 +1,15 @@
 import os
 
 import pytest
+import torch
 
 import whisper
 
 
-@pytest.mark.parametrize('model_name', whisper.available_models())
+@pytest.mark.parametrize("model_name", whisper.available_models())
 def test_transcribe(model_name: str):
-    model = whisper.load_model(model_name).cuda()
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    model = whisper.load_model(model_name).to(device)
     audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac")
 
     language = "en" if model_name.endswith(".en") else None
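
The substantive change is the device fallback: torch.cuda.is_available() selects "cuda" when a GPU is present and "cpu" otherwise, so the parametrized test also runs on the CPU-only CI runner configured above. A standalone sketch of the same pattern follows; the model name, audio path, and explicit fp16 flag are illustrative assumptions, not part of the commit:

    # Sketch only: load a model on whichever device is available and
    # transcribe the repository's sample clip.
    import os

    import torch
    import whisper

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = whisper.load_model("tiny.en").to(device)

    audio_path = os.path.join("tests", "jfk.flac")
    # fp16=False keeps decoding in FP32, avoiding the FP16-on-CPU warning.
    result = model.transcribe(audio_path, language="en", fp16=False)
    print(result["text"])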