소스 검색

torch.concatenate -> torch.cat for compatibility

Jong Wook Kim 2 년 전
부모
커밋
f82bc59f5e
1개의 변경된 파일1개의 추가작업 그리고 1개의 파일을 삭제
  1. 1 1
      notebooks/Multilingual_ASR.ipynb

+ 1 - 1
notebooks/Multilingual_ASR.ipynb

@@ -3608,7 +3608,7 @@
         "    with torch.no_grad():\n",
         "        logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))\n",
         "\n",
-        "    weights = torch.concatenate(QKs)  # layers * heads * tokens * frames    \n",
+        "    weights = torch.cat(QKs)  # layers * heads * tokens * frames    \n",
         "    weights = weights[:, :, :, : duration // AUDIO_SAMPLES_PER_TOKEN].cpu()\n",
         "    weights = medfilt(weights, (1, 1, 1, medfilt_width))\n",
         "    weights = torch.tensor(weights * qk_scale).softmax(dim=-1)\n",