|  | @@ -3608,7 +3608,7 @@
 | 
											
												
													
														|  |          "    with torch.no_grad():\n",
 |  |          "    with torch.no_grad():\n",
 | 
											
												
													
														|  |          "        logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))\n",
 |  |          "        logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))\n",
 | 
											
												
													
														|  |          "\n",
 |  |          "\n",
 | 
											
												
													
														|  | -        "    weights = torch.concatenate(QKs)  # layers * heads * tokens * frames    \n",
 |  | 
 | 
											
												
													
														|  | 
 |  | +        "    weights = torch.cat(QKs)  # layers * heads * tokens * frames    \n",
 | 
											
												
													
														|  |          "    weights = weights[:, :, :, : duration // AUDIO_SAMPLES_PER_TOKEN].cpu()\n",
 |  |          "    weights = weights[:, :, :, : duration // AUDIO_SAMPLES_PER_TOKEN].cpu()\n",
 | 
											
												
													
														|  |          "    weights = medfilt(weights, (1, 1, 1, medfilt_width))\n",
 |  |          "    weights = medfilt(weights, (1, 1, 1, medfilt_width))\n",
 | 
											
												
													
														|  |          "    weights = torch.tensor(weights * qk_scale).softmax(dim=-1)\n",
 |  |          "    weights = torch.tensor(weights * qk_scale).softmax(dim=-1)\n",
 |