{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "FWl3jlFfVR37" }, "source": [ "# Programming Exercise 4: Transformers and Attention\n", "\n", "## Very Deep Learning (VDL) - Winter Semester 2023/24\n", "\n", "---\n", "\n", "### Group Details:\n", "\n", "- **Group Name:** Group 4\n", "\n", "### Members:\n", "\n", "- Frederick Phillips, 404986\n", "- Niklas Eberts, 409829\n", "- Muhammad Saad Najib, 423595\n", "- Rea Fernandes, 426401\n", "- Mayank Chetan Ahuja, 426518\n", "- Caina Rose Paul, 426291\n", "---\n", "\n", "**Instructions**: The tasks in this notebook are a part of Sheet 4. Look for `TODO` tags throughout the notebook and complete the sections with missing code. Once done, ensure all outputs are visible and correctly displayed. Save your notebook and submit the `.ipynb` file together with the exercise sheet PDF in a single ZIP file." ] }, { "cell_type": "markdown", "metadata": { "id": "fHv11KWsbpdi" }, "source": [ "## Introduction to Transformers\n", "Transformers have revolutionized the field of natural language processing and beyond. This tutorial will guide you through the core concepts of transformer models, focusing on attention mechanisms.\n", "\n", "Before diving into the practical aspects, familiarize yourself with the original Transformer paper: \"Attention Is All You Need\" by Vaswani et al. (2017). This will provide a solid theoretical foundation." 
] }, { "cell_type": "code", "execution_count": 7, "metadata": { "id": "KtDpAjvhGm7n" }, "outputs": [], "source": [ "# Setup: Import necessary libraries\n", "import torch\n", "import torch.nn as nn\n", "import torch.optim as optim\n", "import torch.nn.functional as F\n", "from torch.utils.data import DataLoader, Dataset\n", "import math\n", "\n", "# Use the GPU if CUDA is available, otherwise fall back to the CPU\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n" ] },
{ "cell_type": "markdown", "metadata": { "id": "Zj-0why1by3E" }, "source": [ "## Task 1: Implementing Scaled Dot-Product Attention\n", "**Objective**: Implement the scaled dot-product attention mechanism as described in the Transformer paper.\n", "\n", "- **Subtask 1**: Define a function for scaled dot-product attention (1).\n", "- **Subtask 2**: Test the function with a small example (1)." ] },
{ "cell_type": "code", "execution_count": 8, "metadata": { "id": "UWaCwN7XI77C" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Output shape: torch.Size([1, 10, 512])\n", "Attention Weights shape: torch.Size([1, 10, 10])\n" ] } ], "source": [ "# TODO: Implement Scaled Dot-Product Attention\n", "\n", "class ScaledDotProductAttention(nn.Module):\n", "    def __init__(self):\n", "        super().__init__()\n", "\n", "    def forward(self, query, key, value, mask=None):\n", "        # query: (..., L_q, d_k), key: (..., L_k, d_k), value: (..., L_k, d_v)\n", "        d_k = query.size(-1)\n", "        # Similarity scores, scaled by sqrt(d_k) so large dot products do not saturate the softmax\n", "        scores = query.matmul(key.transpose(-2, -1)) / math.sqrt(d_k)\n", "\n", "        if mask is not None:\n", "            # Positions where mask == 0 get a large negative score, i.e. ~0 weight after the softmax\n", "            scores = scores.masked_fill(mask == 0, -1e9)\n", "\n", "        p_attn = F.softmax(scores, dim=-1)  # (..., L_q, L_k); each row sums to 1\n", "        return p_attn.matmul(value), p_attn\n", "\n", "# Initialize the attention mechanism\n", "attention = ScaledDotProductAttention()\n", "\n", "# Define query, key, value with shape (batch_size=1, seq_len=10, d_k=512).\n", "# With seq_len > 1, each query attends over several keys instead of a single one.\n", "query = torch.rand(1, 10, 512)\n", "key = torch.rand(1, 10, 512)\n", "value = torch.rand(1, 10, 512)\n", "\n", "# Forward pass through the attention mechanism\n", "output, attention_weights = attention(query, key, value)\n", "\n", "print(\"Output shape:\", output.shape)\n", "print(\"Attention Weights shape:\", attention_weights.shape)" ] },
{ "cell_type": "markdown", "metadata": { "id": "GLRxWgm7cBdD" }, "source": [ "## Task 2: Multi-Head Attention\n", "**Objective**: Understand and implement Multi-Head Attention.\n", "\n", "- **Subtask 1**: Implement the Multi-Head Attention module (1).\n", "- **Subtask 2**: Test the function with a small example (1)." ] },
{ "cell_type": "code", "execution_count": 9, "metadata": { "id": "GoXh3bWZJraR" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([64, 10, 512])\n" ] } ], "source": [ "# TODO: Implement Multi-Head Attention\n", "\n", "class MultiHeadAttention(nn.Module):\n", "    def __init__(self, d_model, num_heads):\n", "        super(MultiHeadAttention, self).__init__()\n", "        assert d_model % num_heads == 0, \"d_model must be divisible by num_heads\"\n", "\n", "        self.d_model = d_model\n", "        self.num_heads = num_heads\n", "        self.head_dim = d_model // num_heads\n", "\n", "        self.q_linear = nn.Linear(d_model, d_model)\n", "        self.k_linear = nn.Linear(d_model, d_model)\n", "        self.v_linear = nn.Linear(d_model, d_model)\n", "        self.out = nn.Linear(d_model, d_model)\n", "\n", "    def forward(self, query, key, value):\n", "        N = query.shape[0]\n", "\n", "        # Get Q, K, V: (N, seq_len, d_model)\n", "        Q = self.q_linear(query)\n", "        K = self.k_linear(key)\n", "        V = self.v_linear(value)\n", "\n", "        # Split the last dimension into (num_heads, head_dim)\n", "        Q = Q.reshape(N, -1, self.num_heads, self.head_dim)\n", "        K = K.reshape(N, -1, self.num_heads, self.head_dim)\n", "        V = V.reshape(N, -1, self.num_heads, self.head_dim)\n", "\n", "        # Compute scaled dot-product attention per head: (N, num_heads, L_q, L_k).\n", "        # Scale by sqrt(head_dim), the per-head key dimension d_k from the paper.\n", "        energy = torch.einsum(\"nqhd,nkhd->nhqk\", [Q, K])\n", "        attention = torch.softmax(energy / (self.head_dim ** 0.5), dim=3)\n", "        # Weight the values, then merge the heads back into d_model\n", "        out = torch.einsum(\"nhql,nlhd->nqhd\", [attention, V]).reshape(N, -1, self.d_model)\n", "\n", "        # Pass through the final linear layer\n", "        out = self.out(out)\n", "\n", "        return out\n", "\n", "mha = MultiHeadAttention(d_model=512, num_heads=8)\n", "x = torch.rand(64, 10, 512)  # batch_size=64, sequence_length=10, d_model=512\n", "out = mha(x, x, x)\n", "print(out.shape)  # Should print: torch.Size([64, 10, 512])" ] },
{ "cell_type": "markdown", "metadata": { "id": "e77TefWbcgLH" }, "source": [ "## Task 3: Positional Encoding\n", "**Objective**: Implement positional encoding to add information about the sequence order.\n", "\n", "- **Subtask 1**: Implement the positional encoding module (1).\n", "- **Subtask 2**: Test the function with a small example (1)." ] },
{ "cell_type": "code", "execution_count": 10, "metadata": { "id": "sqF-HFdQJ4nB" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([[[ 0.0000e+00,  1.0000e+00,  0.0000e+00,  1.0000e+00,  0.0000e+00,\n", "           1.0000e+00,  0.0000e+00,  1.0000e+00,  0.0000e+00,  1.0000e+00,\n", "           0.0000e+00,  1.0000e+00,  0.0000e+00,  1.0000e+00,  0.0000e+00,\n", "           1.0000e+00,  0.0000e+00,  1.0000e+00,  0.0000e+00,  1.0000e+00],\n", "         [ 8.4147e-01,  5.4030e-01,  3.8767e-01,  9.2180e-01,  1.5783e-01,\n", "           9.8747e-01,  6.3054e-02,  9.9801e-01,  2.5116e-02,  9.9968e-01,\n", "           9.9998e-03,  9.9995e-01,  3.9811e-03,  9.9999e-01,  1.5849e-03,\n", "           1.0000e+00,  6.3096e-04,  1.0000e+00,  2.5119e-04,  1.0000e+00],\n", "         [ 9.0930e-01, -4.1615e-01,  7.1471e-01,  6.9942e-01,  3.1170e-01,\n", "           9.5018e-01,  1.2586e-01,  9.9205e-01,  5.0217e-02,  9.9874e-01,\n", "           1.9999e-02,  9.9980e-01,  7.9621e-03,  9.9997e-01,  3.1698e-03,\n", "           9.9999e-01,  1.2619e-03,  1.0000e+00,  5.0238e-04,  1.0000e+00],\n", "         [ 1.4112e-01, -9.8999e-01,  9.2997e-01,  3.6764e-01,  4.5775e-01,\n", "           8.8908e-01,  1.8816e-01,  9.8214e-01,  7.5285e-02,  9.9716e-01,\n", "           2.9995e-02,  9.9955e-01,  1.1943e-02,  9.9993e-01,  4.7547e-03,\n", "           9.9999e-01,  1.8929e-03,  1.0000e+00,  7.5357e-04,  1.0000e+00],\n", "         [-7.5680e-01, -6.5364e-01,  9.9977e-01, -2.1631e-02,  5.9234e-01,\n", "           8.0569e-01,  2.4971e-01,  9.6832e-01,  1.0031e-01,  9.9496e-01,\n", "           3.9989e-02,  9.9920e-01,  1.5924e-02,  9.9987e-01,  6.3395e-03,\n", "           9.9998e-01,  2.5238e-03,  1.0000e+00,  1.0048e-03,  1.0000e+00],\n", "         [-9.5892e-01,  2.8366e-01,  9.1320e-01, -4.0752e-01,  7.1207e-01,\n", "           7.0211e-01,  3.1027e-01,  9.5065e-01,  1.2526e-01,  9.9212e-01,\n", "           4.9979e-02,  9.9875e-01,  1.9904e-02,  9.9980e-01,  7.9244e-03,\n", "           9.9997e-01,  3.1548e-03,  1.0000e+00,  1.2559e-03,  1.0000e+00],\n", "         [-2.7942e-01,  9.6017e-01,  6.8379e-01, -7.2968e-01,  8.1396e-01,\n", "           5.8092e-01,  3.6960e-01,  9.2919e-01,  1.5014e-01,  9.8866e-01,\n", "           5.9964e-02,  9.9820e-01,  2.3884e-02,  9.9971e-01,  9.5092e-03,\n", "           9.9995e-01,  3.7857e-03,  9.9999e-01,  1.5071e-03,  1.0000e+00],\n", "         [ 6.5699e-01,  7.5390e-01,  3.4744e-01, -9.3770e-01,  8.9544e-01,\n", "           4.4518e-01,  4.2745e-01,  9.0404e-01,  1.7493e-01,  9.8458e-01,\n", "           6.9943e-02,  9.9755e-01,  2.7864e-02,  9.9961e-01,  1.1094e-02,\n", "           9.9994e-01,  4.4167e-03,  9.9999e-01,  1.7583e-03,  1.0000e+00],\n", "         [ 9.8936e-01, -1.4550e-01, -4.3251e-02, -9.9906e-01,  9.5448e-01,\n", "           2.9827e-01,  4.8360e-01,  8.7529e-01,  1.9960e-01,  9.7988e-01,\n", "           7.9915e-02,  9.9680e-01,  3.1843e-02,  9.9949e-01,  1.2679e-02,\n", "           9.9992e-01,  5.0476e-03,  9.9999e-01,  2.0095e-03,  1.0000e+00],\n", "         [ 4.1212e-01, -9.1113e-01, -4.2718e-01, -9.0417e-01,  9.8959e-01,\n", "           1.4389e-01,  5.3783e-01,  8.4305e-01,  2.2415e-01,  9.7455e-01,\n", "           8.9879e-02,  9.9595e-01,  3.5822e-02,  9.9936e-01,  1.4264e-02,\n", "           9.9990e-01,  5.6786e-03,  9.9998e-01,  2.2607e-03,  1.0000e+00]]])\n" ] } ], "source": [ "# TODO: Implement Positional Encoding\n", "\n", "class PositionalEncoding(nn.Module):\n", "    def __init__(self, d_model, max_len=5000):\n", "        super(PositionalEncoding, self).__init__()\n", "\n", "        # Precompute the encodings once for all positions up to max_len.\n", "        # div_term = 1 / 10000^(2i / d_model) is computed in log space.\n", "        pe = torch.zeros(max_len, d_model)\n", "        position = torch.arange(0, max_len).unsqueeze(1)\n", "        div_term = torch.exp(torch.arange(0, d_model, 2) *\n", "                             -(math.log(10000.0) / d_model))\n", "        # Even dimensions use sine, odd dimensions use cosine\n", "        pe[:, 0::2] = torch.sin(position * div_term)\n", "        pe[:, 1::2] = torch.cos(position * div_term)\n", "        # Shape (1, max_len, d_model); a buffer is saved and moved with the module but not trained\n", "        pe = pe.unsqueeze(0)\n", "        self.register_buffer('pe', pe)\n", "\n", "    def forward(self, x):\n", "        # x: (batch_size, seq_len, d_model); add the encodings of the first seq_len positions\n", "        x = x + self.pe[:, :x.size(1)]\n", "        return x\n", "\n", "# Instantiate the class\n", "pe = PositionalEncoding(20)\n", "\n", "# An all-zero input of shape (1, 10, 20), so the output is exactly the encodings of positions 0-9\n", "x = torch.zeros(1, 10, 20)\n", "\n", "# Pass the tensor through the PositionalEncoding\n", "y = pe(x)\n", "\n", "print(y)\n" ] },
{ "cell_type": "markdown", "metadata": { "id": "Lwe1ftzUc4h9" }, "source": [ "## Additional Resources\n", "Here are some additional resources to deepen your understanding:\n", "\n", "- [\"Illustrated Transformer\" by Jay Alammar](https://jalammar.github.io/illustrated-transformer/).\n", "- PyTorch official documentation and tutorials." ] } ], "metadata": { "colab": { "provenance": [] }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.9" } }, "nbformat": 4, "nbformat_minor": 1 }