test_timing.py 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. import numpy as np
  2. import pytest
  3. import scipy.ndimage
  4. import torch
  5. from whisper.timing import dtw_cpu, dtw_cuda, median_filter
  6. sizes = [
  7. (10, 20),
  8. (32, 16),
  9. (123, 1500),
  10. (234, 189),
  11. ]
  12. shapes = [
  13. (10,),
  14. (1, 15),
  15. (4, 5, 345),
  16. (6, 12, 240, 512),
  17. ]
  18. @pytest.mark.parametrize("N, M", sizes)
  19. def test_dtw(N: int, M: int):
  20. steps = np.concatenate([np.zeros(N - 1), np.ones(M - 1)])
  21. np.random.shuffle(steps)
  22. x = np.random.random((N, M)).astype(np.float32)
  23. i, j, k = 0, 0, 0
  24. trace = []
  25. while True:
  26. x[i, j] -= 1
  27. trace.append((i, j))
  28. if k == len(steps):
  29. break
  30. if k + 1 < len(steps) and steps[k] != steps[k + 1]:
  31. i += 1
  32. j += 1
  33. k += 2
  34. continue
  35. if steps[k] == 0:
  36. i += 1
  37. if steps[k] == 1:
  38. j += 1
  39. k += 1
  40. trace = np.array(trace).T
  41. dtw_trace = dtw_cpu(x)
  42. assert np.allclose(trace, dtw_trace)
  43. @pytest.mark.requires_cuda
  44. @pytest.mark.parametrize("N, M", sizes)
  45. def test_dtw_cuda_equivalence(N: int, M: int):
  46. x_numpy = np.random.randn(N, M).astype(np.float32)
  47. x_cuda = torch.from_numpy(x_numpy).cuda()
  48. trace_cpu = dtw_cpu(x_numpy)
  49. trace_cuda = dtw_cuda(x_cuda)
  50. assert np.allclose(trace_cpu, trace_cuda)
  51. @pytest.mark.parametrize("shape", shapes)
  52. def test_median_filter(shape):
  53. x = torch.randn(*shape)
  54. for filter_width in [3, 5, 7, 13]:
  55. filtered = median_filter(x, filter_width)
  56. # using np.pad to reflect-pad, because Scipy's behavior is different near the edges.
  57. pad_width = filter_width // 2
  58. padded_x = np.pad(
  59. x, [(0, 0)] * (x.ndim - 1) + [(pad_width, pad_width)], mode="reflect"
  60. )
  61. scipy_filtered = scipy.ndimage.median_filter(
  62. padded_x, [1] * (x.ndim - 1) + [filter_width]
  63. )
  64. scipy_filtered = scipy_filtered[..., pad_width:-pad_width]
  65. assert np.allclose(filtered, scipy_filtered)
  66. @pytest.mark.requires_cuda
  67. @pytest.mark.parametrize("shape", shapes)
  68. def test_median_filter_equivalence(shape):
  69. x = torch.randn(*shape)
  70. for filter_width in [3, 5, 7, 13]:
  71. filtered_cpu = median_filter(x, filter_width)
  72. filtered_gpu = median_filter(x.cuda(), filter_width).cpu()
  73. assert np.allclose(filtered_cpu, filtered_gpu)