123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109 |
- from functools import lru_cache
- import numpy as np
- import torch
- try:
- import triton
- import triton.language as tl
- except ImportError:
- raise RuntimeError("triton import failed; try `pip install --pre triton`")
# Dynamic-time-warping cost/trace fill, run as one Triton program.
#
# Loop step k fills one row of the caller-prepared buffers, processing up to
# BLOCK_SIZE columns j in parallel (rows are in units of the *_stride
# arguments, columns in elements):
#   cost[k + 1, j + 1]  = x[k - 1, j] + min(cost[k - 1, j], cost[k, j], cost[k, j + 1])
#   trace[k + 1, j + 1] = 0, 1 or 2 for whichever of those three was smallest
# for j in 0 .. M - 1 and k in 1 .. N + M.
# NOTE(review): presumably the caller stores its matrices in a skewed layout
# so that each row here is one anti-diagonal (k = i + j) -- confirm against
# the caller.
#
# Buffer sizes implied by the indexing: cost and trace need at least
# N + M + 2 rows and M + 1 columns; x needs at least N + M rows of M columns.
# Columns j >= BLOCK_SIZE are never visited, so BLOCK_SIZE should be >= M.
# The kernel never reads tl.program_id, so every program would repeat the
# same work; presumably it is launched with a grid of (1,).
@triton.jit
def dtw_kernel(
    cost, trace, x, x_stride, cost_stride, trace_stride, N, M, BLOCK_SIZE: tl.constexpr
):
    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < M  # only columns 0 .. M-1 are active
    for k in range(1, N + M + 1):
        # Step k reads rows k-1 and k, which earlier steps wrote from other
        # lanes; synchronize first so those stores are complete before reading.
        tl.debug_barrier()
        p0 = cost + (k - 1) * cost_stride  # cost[k - 1, j]
        p1 = cost + k * cost_stride  # cost[k, j]
        p2 = cost + k * cost_stride + 1  # cost[k, j + 1]
        c0 = tl.load(p0 + offsets, mask=mask)
        c1 = tl.load(p1 + offsets, mask=mask)
        c2 = tl.load(p2 + offsets, mask=mask)
        # Local cost for this step; masked-off lanes read 0 and are never stored.
        x_row = tl.load(x + (k - 1) * x_stride + offsets, mask=mask, other=0)
        cost_row = x_row + tl.minimum(tl.minimum(c0, c1), c2)
        cost_ptr = cost + (k + 1) * cost_stride + 1
        tl.store(cost_ptr + offsets, cost_row, mask=mask)
        # Record which predecessor was the minimum. The stores are ordered so
        # that later ones overwrite earlier ones: on ties 0 wins over 1, and 1
        # over 2. If no mask holds (e.g. NaN costs), the entry is left unchanged.
        trace_ptr = trace + (k + 1) * trace_stride + 1
        tl.store(trace_ptr + offsets, 2, mask=mask & (c2 <= c0) & (c2 <= c1))
        tl.store(trace_ptr + offsets, 1, mask=mask & (c1 <= c0) & (c1 <= c2))
        tl.store(trace_ptr + offsets, 0, mask=mask & (c0 <= c1) & (c0 <= c2))
def _expand_placeholder(src: str, placeholder: str, statements: "list[str]") -> str:
    """Replace each line of *src* whose code starts with *placeholder*.

    The matching line is replaced by *statements*, one per line. Each generated
    statement reuses the placeholder line's own leading whitespace, so the
    expansion is correctly indented however the kernel source was indented or
    dedented. Anything after the placeholder on that line (such as a
    ``# noqa`` comment) is dropped along with it.
    """
    pieces = []
    for line in src.splitlines(keepends=True):
        code = line.lstrip()
        if not code.startswith(placeholder):
            pieces.append(line)
            continue
        indent = line[: len(line) - len(code)]
        pieces.extend(f"{indent}{stmt}\n" for stmt in statements)
    return "".join(pieces)


def _set_kernel_src(kernel, src: str) -> None:
    """Install *src* as the source of a Triton ``JITFunction``.

    Some newer Triton releases reject direct assignment to ``.src`` and expose
    ``_unsafe_update_src`` instead. Use that method when it exists and fall
    back to plain assignment otherwise.
    """
    update_src = getattr(kernel, "_unsafe_update_src", None)
    if update_src is None:
        kernel.src = src
    else:
        update_src(src)
        kernel.hash = None  # make Triton recompute the cache key from the new source


@lru_cache(maxsize=None)
def median_kernel(filter_width: int):
    """Build (and memoize per width) a Triton sliding-median kernel.

    Each program handles the row selected by ``tl.program_id(0)``. It loads
    ``filter_width`` shifted copies of that row (``row{i}`` reads the elements
    at ``x_ptr + offsets + i``), runs ``filter_width // 2 + 1`` bubble-sort
    passes over them elementwise, and stores ``row{filter_width // 2}``. That
    is the median, or the upper of the two middle values for an even width.

    The loads and the compare-exchange network are unrolled at build time.
    They are generated as source text for the given width and spliced into
    the template kernel below in place of its placeholder lines.

    Kernel arguments:
        y: output pointer.
        x: input pointer.
        x_stride: element distance between consecutive rows of ``x``.
        y_stride: element distance between consecutive rows of ``y``. It is
            also used as the row length when masking.
        BLOCK_SIZE: power of two; should be ``>= y_stride``, otherwise the
            trailing elements of each row are skipped.

    Raises:
        ValueError: if ``filter_width`` is less than 1.
    """
    if filter_width < 1:
        raise ValueError(f"filter_width must be at least 1, got {filter_width}")

    @triton.jit
    def kernel(y, x, x_stride, y_stride, BLOCK_SIZE: tl.constexpr):
        row_idx = tl.program_id(0)
        offsets = tl.arange(0, BLOCK_SIZE)
        mask = offsets < y_stride

        # x_ptr is only referenced by the generated load statements.
        x_ptr = x + row_idx * x_stride  # noqa: F841
        y_ptr = y + row_idx * y_stride

        LOAD_ALL_ROWS_HERE  # noqa: F821

        BUBBLESORT_HERE  # noqa: F821

        tl.store(y_ptr + offsets, MIDDLE_ROW_HERE, mask=mask)  # noqa: F821

    # Re-wrap so the source edits below apply to a fresh JITFunction object.
    kernel = triton.JITFunction(kernel.fn)

    loads = [
        f"row{i} = tl.load(x_ptr + offsets + {i}, mask=mask)"
        for i in range(filter_width)
    ]

    # Partial bubble sort: pass i moves the largest remaining value up to slot
    # filter_width - i - 1. After filter_width // 2 + 1 passes, slot
    # filter_width // 2 holds its sorted-order value. Each compare-exchange
    # uses a single predicate, so the pair is always a permutation of its
    # inputs. Two independent `<` / `>` selects would duplicate one value and
    # drop the other whenever a comparison involves NaN.
    compare_exchanges = []
    for i in range(filter_width // 2 + 1):
        for j in range(filter_width - i - 1):
            lo, hi = f"row{j}", f"row{j + 1}"
            compare_exchanges += [
                f"swap = {lo} > {hi}",
                f"smaller = tl.where(swap, {hi}, {lo})",
                f"larger = tl.where(swap, {lo}, {hi})",
                f"{lo} = smaller",
                f"{hi} = larger",
            ]

    src = kernel.src
    src = _expand_placeholder(src, "LOAD_ALL_ROWS_HERE", loads)
    src = _expand_placeholder(src, "BUBBLESORT_HERE", compare_exchanges)
    src = src.replace("MIDDLE_ROW_HERE", f"row{filter_width // 2}")
    _set_kernel_src(kernel, src)

    return kernel
def median_filter_cuda(x: torch.Tensor, filter_width: int):
    """Apply a median filter of given width along the last dimension of x

    Output element ``y[..., t]`` is the median of ``x[..., t:t + filter_width]``
    (no padding), so the last dimension shrinks from ``L`` to
    ``L - filter_width + 1``. For an even width, the upper of the two middle
    values is returned.

    Args:
        x: tensor with at least two dimensions, on a device Triton can launch
            on. Any memory layout is accepted; a contiguous copy is used when
            ``x`` is not already contiguous.
        filter_width: number of taps in the window (``1 <= filter_width <= L``).

    Returns:
        A new contiguous tensor of shape
        ``x.shape[:-1] + (L - filter_width + 1,)`` with the dtype and device
        of ``x``.
    """
    # The kernel walks each row as unit-stride elements spaced x.stride(-2)
    # apart. That is only valid for a contiguous tensor, so the same
    # contiguous tensor must feed both the shape computation and the launch.
    x = x.contiguous()
    slices = x.unfold(-1, filter_width, 1)
    y = torch.empty(slices.shape[:-1], dtype=x.dtype, device=x.device)
    if y.numel() == 0:
        return y  # nothing to filter; skip kernel compilation and launch

    # One program per row, where rows are all leading dimensions flattened.
    grid = int(np.prod(slices.shape[:-2]))
    kernel = median_kernel(filter_width)

    # Smallest power of two covering one output row (tl.arange needs a power of two).
    BLOCK_SIZE = 1 << (y.stride(-2) - 1).bit_length()
    kernel[(grid,)](y, x, x.stride(-2), y.stride(-2), BLOCK_SIZE=BLOCK_SIZE)
    return y
|