  • Meant to mimic cognitive attention
    • Picks out the relevant bits of information
    • Which bits are relevant is learned by gradient descent
  • Early forms used in the 1990s
    • Multiplicative modules
    • Sigma-pi units
    • Hyper-networks
  • Can draw from the relevant state at any preceding point in the sequence
    • Addresses the vanishing-gradient problem of RNNs
    • LSTMs tend to preserve information from far back poorly
  • An attention layer can access all previous states and weighs them according to a learned measure of relevance
    • Allows referring to relevant tokens arbitrarily far back
  • Can be added to RNNs
  • In 2016, a new type of highly parallelisable decomposable attention was successfully combined with a feedforward network
    • Showed attention is useful in and of itself, not just with RNNs
  • Transformers use attention without recurrent connections
    • Process all tokens simultaneously
    • Calculate attention weights in successive layers
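
To make the "learned measure of relevance" concrete, here is a minimal sketch in plain Python (lists only, no libraries) of an attention layer reading an RNN's earlier hidden states. The notes don't say how relevance is scored; this assumes a bilinear score $h_t^\top W h_j$ with a hypothetical learned matrix `W`, and the function names are illustrative. Every earlier state receives a weight directly, so there is no chain of recurrent steps for the signal to fade through.

```python
import math

# A minimal sketch: attention over an RNN's previous hidden states.
# Assumption: relevance is scored bilinearly as h_t^T W h_j, with W a
# hypothetical learned matrix (one common choice; others exist).

def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def matvec(W, v):
    """Multiply matrix W (a list of rows) by vector v."""
    return [sum(w * x for w, x in zip(row, v)) for row in W]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def attend_to_history(h_t, history, W):
    """Weigh every earlier state h_0..h_{t-1} by its relevance to h_t.

    Returns (weights, context): the softmax-normalised weights and the
    weighted sum of the earlier states.
    """
    scores = [dot(h_t, matvec(W, h_j)) for h_j in history]  # h_t^T W h_j
    weights = softmax(scores)
    dim = len(history[0])
    context = [sum(w * h_j[d] for w, h_j in zip(weights, history)) for d in range(dim)]
    return weights, context

# Usage: the oldest state matches h_t best, so it gets the largest weight,
# regardless of how many steps back it lies.
history = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
weights, context = attend_to_history([3.0, 0.0], history, [[1.0, 0.0], [0.0, 1.0]])
```

The weights depend only on content, not on distance, which is what lets attention reach arbitrarily far back.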

Scaled Dot-Product

  • Calculate attention weights between all tokens at once
  • Learn three weight matrices
    • Query
      • $W_Q$
    • Key
      • $W_K$
    • Value
      • $W_V$
  • Word vectors
    • For each token $i$, multiply its input word embedding $x_i$ by each of the above matrices to produce three vectors
    • Query vector
      • $q_i = x_i W_Q$
    • Key vector
      • $k_i = x_i W_K$
    • Value vector
      • $v_i = x_i W_V$
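  • Attention weights
    • Score between tokens $i$ and $j$ is the dot product of the query vector of $i$ and the key vector of $j$
      • $q_i \cdot k_j$
    • Divided by the square root of the dimensionality of the key vectors
      • $\sqrt{d_k}$
    • Passed through softmax over $j$ to normalise, so the weights for token $i$ sum to 1
      • $a_{ij} = \text{softmax}_j\left(\frac{q_i \cdot k_j}{\sqrt{d_k}}\right)$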
  • $W_Q$ and $W_K$ are different matrices
    • Attention can be non-symmetric
    • Token $i$ attends to $j$ ($q_i \cdot k_j$ is large)
      • Doesn’t imply that $j$ attends to $i$ ($q_j \cdot k_i$ can be small)
  • Output for token $i$ is the sum of the value vectors of all tokens, weighted by $a_{ij}$
    • $a_{ij}$ is the attention from token $i$ to token $j$
    • $o_i = \sum_j a_{ij} v_j$
  • $Q, K, V$ are matrices whose $i$th rows are the vectors $q_i, k_i, v_i$ respectively
    • $\text{Attention}(Q,K,V)=\text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$
  • Softmax is taken over the horizontal axis (along each row), so row $i$ holds the weights $a_{ij}$ for query $i$
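
The matrix form can be sketched end to end in plain Python, treating lists of rows as matrices; a real model would use a tensor library with batched, multi-head versions. This is a single-head sketch with no causal mask, and the function names are illustrative rather than any library's API. The usage lines pick $W_Q \neq W_K$ to show the non-symmetry point: token 0 attends mostly to token 1, but not the reverse.

```python
import math

# A minimal sketch of scaled dot-product self-attention using lists as
# matrices (single head, no masking). Helper names are illustrative.

def matmul(A, B):
    """(n x m) @ (m x p) for matrices stored as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def transpose(A):
    return [list(col) for col in zip(*A)]

def softmax_rows(S):
    """Softmax along each row (the horizontal axis), numerically stabilised."""
    out = []
    for row in S:
        m = max(row)
        exps = [math.exp(s - m) for s in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out

def scaled_dot_product_attention(X, W_Q, W_K, W_V):
    """X has one row per token (x_i). Returns (A, output) with A[i][j] = a_ij."""
    Q = matmul(X, W_Q)  # row i is q_i = x_i W_Q
    K = matmul(X, W_K)  # row i is k_i = x_i W_K
    V = matmul(X, W_V)  # row i is v_i = x_i W_V
    d_k = len(K[0])
    scores = matmul(Q, transpose(K))  # scores[i][j] = q_i . k_j
    A = softmax_rows([[s / math.sqrt(d_k) for s in row] for row in scores])
    return A, matmul(A, V)  # output row i = sum_j a_ij v_j

# Usage: with W_Q != W_K, token 0 attends mostly to token 1 (a_01 is about 0.80)
# while token 1 splits its attention evenly (a_10 = 0.5).
X = [[1.0, 0.0], [0.0, 1.0]]
W_Q = [[0.0, 2.0], [0.0, 0.0]]
W_K = [[1.0, 0.0], [0.0, 1.0]]
W_V = [[1.0, 0.0], [0.0, 1.0]]
A, output = scaled_dot_product_attention(X, W_Q, W_K, W_V)
```

Each row of `A` sums to 1 because the softmax runs along rows; with `W_V` set to the identity here, `output` equals `A`.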