# triton_ops.py

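# Triton GPU kernels: dtw_kernel fills a dynamic time warping (DTW) cost/trace
# table one anti-diagonal at a time, and median_kernel / median_filter_cuda run
# a sliding-window median filter along the last dimension of a CUDA tensor.
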
from functools import lru_cache

import numpy as np
import torch

try:
    import triton
    import triton.language as tl
except ImportError:
    raise RuntimeError("triton import failed; try `pip install --pre triton`")


@triton.jit
def dtw_kernel(
    cost, trace, x, x_stride, cost_stride, trace_stride, N, M, BLOCK_SIZE: tl.constexpr
):
    # Fill the skewed DTW cost/trace tables one anti-diagonal (k = i + j) at a
    # time; each row of `cost`, `trace`, and `x` holds one anti-diagonal.
    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < M

    for k in range(1, N + M + 1):  # k = i + j
        tl.debug_barrier()
        # The three DTW predecessors of the cells on diagonal k + 1 live on
        # diagonals k - 1 (c0) and k (c1, c2).
        p0 = cost + (k - 1) * cost_stride
        p1 = cost + k * cost_stride
        p2 = cost + k * cost_stride + 1

        c0 = tl.load(p0 + offsets, mask=mask)
        c1 = tl.load(p1 + offsets, mask=mask)
        c2 = tl.load(p2 + offsets, mask=mask)

        x_row = tl.load(x + (k - 1) * x_stride + offsets, mask=mask, other=0)
        cost_row = x_row + tl.minimum(tl.minimum(c0, c1), c2)

        cost_ptr = cost + (k + 1) * cost_stride + 1
        tl.store(cost_ptr + offsets, cost_row, mask=mask)

        # Record which predecessor won (0, 1, or 2) for the later backtrace.
        trace_ptr = trace + (k + 1) * trace_stride + 1
        tl.store(trace_ptr + offsets, 2, mask=mask & (c2 <= c0) & (c2 <= c1))
        tl.store(trace_ptr + offsets, 1, mask=mask & (c1 <= c0) & (c1 <= c2))
        tl.store(trace_ptr + offsets, 0, mask=mask & (c0 <= c1) & (c0 <= c2))
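

# ---------------------------------------------------------------------------
# Hedged launch sketch (not part of the original file): the real caller lives
# elsewhere in the codebase. The helper name, the single-program grid, and the
# (N + M + 2, M + 2) table shapes below are assumptions inferred from the
# kernel's pointer arithmetic, not a documented API.
# ---------------------------------------------------------------------------
def _dtw_launch_sketch(x_skew: torch.Tensor, N: int, M: int, BLOCK_SIZE: int = 1024):
    """Allocate skewed cost/trace tables and run dtw_kernel over them once.

    x_skew is assumed to be a CUDA tensor of shape (N + M, M) whose row d holds
    anti-diagonal d of an underlying M x N distance matrix, i.e. entry i of the
    row is dist[i, d - i], with +inf padding for cells outside the matrix.
    """
    assert M <= BLOCK_SIZE, "BLOCK_SIZE must cover a whole anti-diagonal"
    cost = torch.full(
        (N + M + 2, M + 2), np.inf, device=x_skew.device, dtype=torch.float32
    )
    cost[0, 0] = 0  # DTW starts from the top-left corner
    trace = torch.zeros_like(cost, dtype=torch.int32)

    # A single program sweeps all anti-diagonals sequentially.
    dtw_kernel[(1,)](
        cost, trace, x_skew,
        x_skew.stride(0), cost.stride(0), trace.stride(0),
        N, M, BLOCK_SIZE=BLOCK_SIZE,
    )
    return cost, trace  # trace holds 0/1/2 per cell for a later backtrace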


@lru_cache(maxsize=None)
def median_kernel(filter_width: int):
    @triton.jit
    def kernel(
        y, x, x_stride, y_stride, BLOCK_SIZE: tl.constexpr
    ):  # x.shape[-1] == filter_width
        # The ALL-CAPS placeholders below are rewritten into real code for this
        # filter_width before compilation (see the kernel.src.replace calls below).
        row_idx = tl.program_id(0)
        offsets = tl.arange(0, BLOCK_SIZE)
        mask = offsets < y_stride

        x_ptr = x + row_idx * x_stride  # noqa: F841
        y_ptr = y + row_idx * y_stride

        LOAD_ALL_ROWS_HERE  # noqa: F821

        BUBBLESORT_HERE  # noqa: F821

        tl.store(y_ptr + offsets, MIDDLE_ROW_HERE, mask=mask)  # noqa: F821

    # Fill the placeholders by rewriting the kernel's source: unrolled loads of
    # the filter window, a fixed bubble-sort network, and the middle row.
    kernel = triton.JITFunction(kernel.fn)
    kernel.src = kernel.src.replace(
        "    LOAD_ALL_ROWS_HERE",
        "\n".join(
            [
                f"    row{i} = tl.load(x_ptr + offsets + {i}, mask=mask)"
                for i in range(filter_width)
            ]
        ),
    )
    kernel.src = kernel.src.replace(
        "    BUBBLESORT_HERE",
        "\n\n".join(
            [
                "\n\n".join(
                    [
                        "\n".join(
                            [
                                f"    smaller = tl.where(row{j} < row{j + 1}, row{j}, row{j + 1})",
                                f"    larger = tl.where(row{j} > row{j + 1}, row{j}, row{j + 1})",
                                f"    row{j} = smaller",
                                f"    row{j + 1} = larger",
                            ]
                        )
                        for j in range(filter_width - i - 1)
                    ]
                )
                for i in range(filter_width // 2 + 1)
            ]
        ),
    )
    kernel.src = kernel.src.replace("MIDDLE_ROW_HERE", f"row{filter_width // 2}")
    return kernel
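

# Illustration (not executed; assumes filter_width == 3): median_kernel(3)
# rewrites the placeholders above so the generated kernel body is equivalent to
#
#     row0 = tl.load(x_ptr + offsets + 0, mask=mask)
#     row1 = tl.load(x_ptr + offsets + 1, mask=mask)
#     row2 = tl.load(x_ptr + offsets + 2, mask=mask)
#     smaller = tl.where(row0 < row1, row0, row1)
#     larger = tl.where(row0 > row1, row0, row1)
#     row0 = smaller
#     row1 = larger
#     # ... the same compare-exchange for (row1, row2) and again for (row0, row1) ...
#     tl.store(y_ptr + offsets, row1, mask=mask)
#
# After the three exchanges, row1 holds the median of each 3-wide window.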


def median_filter_cuda(x: torch.Tensor, filter_width: int):
    """Apply a median filter of given width along the last dimension of x"""
    slices = x.contiguous().unfold(-1, filter_width, 1)
    grid = np.prod(slices.shape[:-2])

    kernel = median_kernel(filter_width)
    y = torch.empty_like(slices[..., 0])

    # Smallest power of two covering one output row; the kernel masks offsets >= y_stride.
    BLOCK_SIZE = 1 << (y.stride(-2) - 1).bit_length()
    kernel[(grid,)](y, x, x.stride(-2), y.stride(-2), BLOCK_SIZE=BLOCK_SIZE)

    return y
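

# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original file): the shapes and filter
# width below are arbitrary examples. Because unfold() is used with step 1, the
# output is filter_width - 1 samples shorter than the input; pad the input
# beforehand if the original length must be preserved.
# ---------------------------------------------------------------------------
if __name__ == "__main__" and torch.cuda.is_available():
    signal = torch.randn(4, 80, 3000, device="cuda")  # (batch, channels, samples)
    smoothed = median_filter_cuda(signal, filter_width=7)
    print(smoothed.shape)  # expected: torch.Size([4, 80, 2994])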