test_timing.py 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. import pytest
  2. import numpy as np
  3. import scipy.ndimage
  4. import torch
  5. from whisper.timing import dtw_cpu, dtw_cuda, median_filter
  6. sizes = [
  7. (10, 20), (32, 16), (123, 1500), (234, 189),
  8. ]
  9. shapes = [
  10. (10,), (1, 15), (4, 5, 345), (6, 12, 240, 512),
  11. ]
  12. @pytest.mark.parametrize("N, M", sizes)
  13. def test_dtw(N: int, M: int):
  14. steps = np.concatenate([np.zeros(N - 1), np.ones(M - 1)])
  15. np.random.shuffle(steps)
  16. x = np.random.random((N, M)).astype(np.float32)
  17. i, j, k = 0, 0, 0
  18. trace = []
  19. while True:
  20. x[i, j] -= 1
  21. trace.append((i, j))
  22. if k == len(steps):
  23. break
  24. if k + 1 < len(steps) and steps[k] != steps[k + 1]:
  25. i += 1
  26. j += 1
  27. k += 2
  28. continue
  29. if steps[k] == 0:
  30. i += 1
  31. if steps[k] == 1:
  32. j += 1
  33. k += 1
  34. trace = np.array(trace).T
  35. dtw_trace = dtw_cpu(x)
  36. assert np.allclose(trace, dtw_trace)
  37. @pytest.mark.requires_cuda
  38. @pytest.mark.parametrize("N, M", sizes)
  39. def test_dtw_cuda_equivalence(N: int, M: int):
  40. x_numpy = np.random.randn(N, M).astype(np.float32)
  41. x_cuda = torch.from_numpy(x_numpy).cuda()
  42. trace_cpu = dtw_cpu(x_numpy)
  43. trace_cuda = dtw_cuda(x_cuda)
  44. assert np.allclose(trace_cpu, trace_cuda)
  45. @pytest.mark.parametrize("shape", shapes)
  46. def test_median_filter(shape):
  47. x = torch.randn(*shape)
  48. for filter_width in [3, 5, 7, 13]:
  49. filtered = median_filter(x, filter_width)
  50. # using np.pad to reflect-pad, because Scipy's behavior is different near the edges.
  51. pad_width = filter_width // 2
  52. padded_x = np.pad(x, [(0, 0)] * (x.ndim - 1) + [(pad_width, pad_width)], mode="reflect")
  53. scipy_filtered = scipy.ndimage.median_filter(padded_x, [1] * (x.ndim - 1) + [filter_width])
  54. scipy_filtered = scipy_filtered[..., pad_width:-pad_width]
  55. assert np.allclose(filtered, scipy_filtered)
  56. @pytest.mark.requires_cuda
  57. @pytest.mark.parametrize("shape", shapes)
  58. def test_median_filter_equivalence(shape):
  59. x = torch.randn(*shape)
  60. for filter_width in [3, 5, 7, 13]:
  61. filtered_cpu = median_filter(x, filter_width)
  62. filtered_gpu = median_filter(x.cuda(), filter_width).cpu()
  63. assert np.allclose(filtered_cpu, filtered_gpu)