Hacker News

Ask HN: Can someone ELI5 transformers and the “Attention Is All You Need” paper?

I have zero AI/ML knowledge but Steve Yegge on Medium thinks that the team behind Transformers deserves a Nobel.

Makes me want to better understand this tech.

Edit: thank you for some amazing top level responses and links to valuable content on this subject.


benjismith

Okay, here's my attempt!

First, we take a sequence of words and represent it as a grid of numbers: each column of the grid is a separate word, and each row of the grid is a measurement of some property of that word. Words with similar meanings are likely to have similar numerical values on a row-by-row basis.

(During the training process, we create a dictionary of all possible words, with a column of numbers for each of those words. More on this later!)

This grid is called the "context". Typical systems will have a context that spans several thousand columns and several thousand rows. Right now, context length (column count) is rapidly expanding (1k to 2k to 8k to 32k to 100k+!!) while the dimensionality of each word in the dictionary (row count) is pretty static at around 4k to 8k...
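To make the grid concrete, here is a minimal Python sketch. It assumes a made-up three-word dictionary with only 4 rows per word (real models learn thousands of rows per token, and learn the values during training). `dictionary` and `build_context` are hypothetical names chosen for illustration.

```python
# Hypothetical toy dictionary: each word maps to a column of 4 numbers.
# In a real model these values are learned and there are thousands of rows.
dictionary = {
    "the": [0.1, 0.0, 0.3, 0.2],
    "cat": [0.9, 0.8, 0.1, 0.0],
    "sat": [0.2, 0.1, 0.9, 0.7],
}

def build_context(words):
    """Return the context grid as a list of rows; column j holds word j's numbers."""
    columns = [dictionary[w] for w in words]
    n_rows = len(columns[0])
    # Transpose so that each row is one measured property across all words.
    return [[col[r] for col in columns] for r in range(n_rows)]

context = build_context(["the", "cat", "sat"])
for row in context:
    print(row)
```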

Anyhow, the Transformer architecture takes that grid and passes it through a multi-layer transformation algorithm. The functionality of each layer is identical: receive the grid of numbers as input, then perform a mathematical transformation on the grid of numbers, and pass it along to the next layer.

Most systems these days have around 64 or 96 layers.

After the grid of numbers has passed through all the layers, we can use it to generate a new column of numbers that predicts the properties of some word that would maximize the coherence of the sequence if we add it to the end of the grid. We take that new column of numbers and comb through our dictionary to find the actual word that most-closely matches the properties we're looking for.

That word is the winner! We add it to the sequence as a new column, remove the first column if the context is already full, and run the whole process again! That's how we generate long text-completions one word at a time :D
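Here is a rough sketch of that word-picking and sliding-window loop. The `predict_next_vector` function is a hypothetical stand-in for the whole stack of layers, and cosine similarity is one assumed way to measure the "closest match". Real models actually output a probability for every token in the vocabulary, but the append-and-slide bookkeeping is the same.

```python
import math

# Toy dictionary (made-up values): one column of numbers per word.
dictionary = {
    "the": [1.0, 0.0, 0.0],
    "cat": [0.0, 1.0, 0.0],
    "sat": [0.0, 0.0, 1.0],
}

def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))

def closest_word(vector):
    # Comb through the dictionary for the word whose column best matches the prediction.
    return max(dictionary, key=lambda w: cosine(dictionary[w], vector))

def predict_next_vector(context):
    # Hypothetical stand-in for the layer stack: returns a vector close to
    # the word that follows the last word in a fixed toy cycle.
    cycle = {"the": [0.1, 0.9, 0.0], "cat": [0.0, 0.2, 0.8], "sat": [0.9, 0.1, 0.0]}
    return cycle[context[-1]]

def generate(context, steps, max_len):
    context = list(context)
    output = []
    for _ in range(steps):
        word = closest_word(predict_next_vector(context))
        output.append(word)
        context.append(word)      # add the winner as a new column
        if len(context) > max_len:
            context.pop(0)        # drop the oldest column once the window is full
    return output, context

words, window = generate(["the"], steps=4, max_len=3)
print(words)   # ['cat', 'sat', 'the', 'cat']
print(window)  # ['sat', 'the', 'cat']
```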

So the interesting bits are located within that stack of layers. This is why it's called "deep learning".

The mathematical transformation in each layer is called "self-attention", and it involves a lot of matrix multiplications and dot-product calculations with a learned set of "Query, Key and Value" matrices.

It can be hard to understand what these layers are doing linguistically, but we can use image-processing and computer-vision as a good metaphor, since images are also grids of numbers, and we've all seen how photo-filters can transform that entire grid in lots of useful ways...

You can think of each layer in the transformer as being like a "mask" or "filter" that selects various interesting features from the grid, and then tweaks the image with respect to those masks and filters.

In image processing, you might apply a color-channel mask (chroma key) to select all the green pixels in the background, so that you can erase the background and replace it with other footage. Or you might apply a "gaussian blur" that mixes each pixel with its nearest neighbors, to create a blurring effect. Or you might do the inverse of a gaussian blur, to create a "sharpening" operation that helps you find edges...

But the basic idea is that you have a library of operations that you can apply to a grid of pixels, in order to transform the image (or part of the image) for a desired effect. And you can stack these transforms to create arbitrarily-complex effects.
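As an illustration of stacking filters, here is a small sketch (plain Python, made-up 5x5 image) that applies a 3x3 box blur, where each output pixel is the average of its neighbourhood, and then a common 3x3 sharpen kernel to the blurred result. Each pass shrinks the grid because this sketch uses no padding at the edges.

```python
def apply_filter(image, kernel):
    """Slide a 3x3 kernel over the image (no padding), returning a smaller grid."""
    h, w = len(image), len(image[0])
    out = []
    for r in range(1, h - 1):
        row = []
        for c in range(1, w - 1):
            # Weighted sum of the pixel and its 8 neighbours.
            total = 0.0
            for dr in (-1, 0, 1):
                for dc in (-1, 0, 1):
                    total += kernel[dr + 1][dc + 1] * image[r + dr][c + dc]
            row.append(total)
        out.append(row)
    return out

box_blur = [[1 / 9] * 3 for _ in range(3)]          # average of the 3x3 neighbourhood
sharpen = [[0, -1, 0], [-1, 5, -1], [0, -1, 0]]     # boosts a pixel relative to its neighbours

image = [
    [0, 0, 0, 0, 0],
    [0, 0, 9, 0, 0],
    [0, 9, 9, 9, 0],
    [0, 0, 9, 0, 0],
    [0, 0, 0, 0, 0],
]
blurred = apply_filter(image, box_blur)   # 3x3 result
stacked = apply_filter(blurred, sharpen)  # filters chain: 1x1 result
print(blurred)
print(stacked)
```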

The same thing is true in a linguistic transformer, where a text sequence is modeled as a matrix.

The language model has a library of "Query, Key and Value" matrices (which were learned during training) that are roughly analogous to the "Masks and Filters" we use on images.

Each layer in the Transformer architecture attempts to identify some features of the incoming linguistic data, and then, having identified those features, it can subtract those features from the matrix, so that the next layer sees only the transformation, rather than the original.

We don't know exactly what each of these layers is doing in a linguistic model, but we can imagine it's probably doing things like: performing part-of-speech identification (in this context, is the word "ring" a noun or a verb?), reference resolution (who does the word "he" refer to in this sentence?), etc, etc.

And the "dot-product" calculations in each attention layer are there to make each word "entangled" with its neighbors, so that we can discover all the ways that each word is connected to all the other words in its context.

So... that's how we generate word-predictions (aka "inference") at runtime!

But why does it work?

To understand why it's so effective, you have to understand a bit about the training process.

During inference, data always flows in the same direction. It's called a "feed-forward" network.

But during training, there's another step called "back-propagation".

For each document in our training corpus, we go through all the steps I described above, passing each word into our feed-forward neural network and making word-predictions. We start out with a completely randomized set of QKV matrixes, so the results are often really bad!

During training, when we make a prediction, we KNOW what word is supposed to come next. And we have a numerical representation of each word (4096 numbers in a column!) so we can measure the error between our predictions and the actual next word. Those "error" measurements are also represented as columns of 4096 numbers (because we measure the error in every dimension).

So we take that error vector and pass it backward through the whole system! Each layer needs to take the back-propagated error matrix and perform tiny adjustments to its Query, Key, and Value matrixes. Having compensated for those errors, it reverses its calculations based on the new QKV, and passes the resultant matrix backward to the previous layer. So we make tiny corrections on all 96 layers, and eventually to the word-vectors in the dictionary itself!
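A toy sketch of the backward pass, assuming a "network" of just two scalar layers that each multiply by one weight, and a squared-error loss. It is not a transformer; it only shows how the output error is passed backward with the chain rule so that each layer gets its own small correction. The learning rate and starting weights are arbitrary.

```python
def forward(x, w1, w2):
    h = w1 * x   # layer 1
    y = w2 * h   # layer 2
    return h, y

def train_step(x, target, w1, w2, lr=0.1):
    h, y = forward(x, w1, w2)
    error = y - target              # how far off the prediction is
    # Loss L = error**2 / 2, so dL/dy = error.
    grad_y = error
    grad_w2 = grad_y * h            # dL/dw2 = dL/dy * dy/dw2
    grad_h = grad_y * w2            # error passed backward into layer 1
    grad_w1 = grad_h * x            # dL/dw1 = dL/dh * dh/dw1
    # Tiny adjustment to each layer's weight, against its gradient.
    return w1 - lr * grad_w1, w2 - lr * grad_w2

w1, w2 = 0.5, 0.5
for _ in range(200):
    w1, w2 = train_step(x=1.0, target=2.0, w1=w1, w2=w2)
print(w1 * w2)  # close to 2.0
```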

Like I said earlier, we don't know exactly what those layers are doing. But we know that they're performing a hierarchical decomposition of concepts.

Hope that helps!

Me1000

This was a very helpful visualization, thank you!

The "entanglement" part intuitively makes sense to me, but one bit I always get caught up on is the key, query, and value matrices. In every self-attention explanation I've read/watched, they tend to get thrown out there, similar to what you did here, but their usage/purpose is left a little vague.

Would you mind trying to explain those in more detail? I've heard the database analogy where you start with a query to get a set of keys which you then use to lookup a value, but that doesn't really compute with my mental model of neural networks.

Is it accurate to say that these separate QKV matrices are layers in the network? That doesn't seem exactly right since I think the self-attention layer as a whole contains these three different matrices. I would assume they got their names for a reason that should make it somewhat easy to explain their individual purposes and what they try to represent in the NN.

benjismith

I'm still trying to get a handle on that part myself... But my ever-evolving understanding goes something like this:

The "Query" matrix is like a mask that is capable of selecting certain kinds of features from the context, while the "Key" matrix focuses the "Query" on specific locations in the context.

Using the Query + Key combination, we select and extract those features from the context matrix. And then we apply the "Value" matrix to those features in order to prepare them for feed-forward into the next layer.

There are multiple "Attention Heads" per layer (GPT-3 had 96 heads per layer), and each Head performs its own separate QKV operation. After applying those 96 Q+K->V attention operations per layer, the results are merged back into a single matrix so that they can be fed-forward into the next layer.
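A simplified sketch of the "split into heads, attend, merge" bookkeeping. To keep it short, it assumes each head simply uses its own slice of the token vectors as its queries, keys and values (a real head multiplies by its own learned Q, K and V matrices first), and it leaves out masking and the output projection. With GPT-3's sizes, each slice would be 12288 / 96 = 128 numbers wide.

```python
import math

def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def attention(qs, ks, vs):
    """Scaled dot-product attention for one head (no mask)."""
    d = len(qs[0])
    out = []
    for q in qs:
        weights = softmax([sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in ks])
        out.append([sum(w * v[i] for w, v in zip(weights, vs)) for i in range(len(vs[0]))])
    return out

def multi_head(x, n_heads):
    """Split each token vector into n_heads slices, attend per head, concatenate."""
    d_head = len(x[0]) // n_heads
    head_outputs = []
    for h in range(n_heads):
        part = [vec[h * d_head:(h + 1) * d_head] for vec in x]
        # Simplification: a real head would apply its own learned Q, K, V matrices here.
        head_outputs.append(attention(part, part, part))
    # Merge: glue each token's per-head outputs back into one full-width vector.
    return [sum((head_outputs[h][t] for h in range(n_heads)), []) for t in range(len(x))]

x = [[1.0, 0.0, 0.5, 0.5], [0.0, 1.0, 0.5, -0.5], [1.0, 1.0, 0.0, 0.0]]
y = multi_head(x, n_heads=2)
print(len(y), len(y[0]))  # 3 4
```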

Or something like that...

I'm still trying to grok it myself, and if anyone here could shed more light on the details, I'd be very grateful!

I'm still trying to understand, for example, how many QKV matrices are actually stored in a model with a particular number of parameters. For example, in a GPT-NeoX-20B model (with 20 billion params) how many distinct Q, K, and V matrices are there, and what is their dimensionality?

EDIT:

I just read Imnimo's comment below, and it provides a much better explanation about QKV vectors. I learned a lot!

ActorNightly

It's basically almost the same as convolution in image processing. For example, you take the 3-channel RGB value of a single pixel and do some math on it together with the values of the surrounding pixels, using weights, which gives you some value(s). Depending on the dimensions of everything, you can end up with a smaller-dimension output, like a single 3-channel RGB value, or a higher-dimension output (e.g. for a 5x5 kernel, you can end up with a 9x9 output).

The confusing part that doesn't get mentioned is that the input vectors (Q, K, V) are weighted, i.e. they are derived from the input with the standard linear transformation y = A*x + b, where x is the input word, A is the linear layer matrix, and b is the bias. Those weights are the things that are learned through the training process.

detrites

That was incredible. Thank you! If you made it into an article with images showing the mask/filter analogy, it might be one of the best/most unique explanations I've seen. Love the ground-up approach beginning with data's shape.

Reminded me of the style of a book on machine learning. If anyone liked this explanation, you may appreciate this book:

https://www.amazon.com/Applied-Machine-Learning-Engineers-Al...

Too

If it only generates one word at a time and then repeats the process, how does it know when to stop?

It feels like this method would create endless ramblings. But we all know you can ask ChatGPT to “summarize in one sentence” and it pulls it off. When speaking yourself, you sort of have to think about how to finish a sentence before you start it in order to explain something cohesively; surely there must be something similar in the AI?

joshhart

One of the “words” is a stop token, which represents the ending of text. So you can say the thing that maximizes coherence is to stop right then.
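A sketch of that stopping rule. `next_token` is a hypothetical stand-in for the model that replays a canned answer, and `<|endoftext|>` is the name GPT-2 uses for its end-of-text token. The `max_tokens` cap is an extra safety limit of the kind generation APIs usually add, not something the model itself needs.

```python
END = "<|endoftext|>"

def next_token(tokens):
    # Hypothetical stand-in for the model: replays a canned answer, then
    # predicts the end-of-text token once the sentence is complete.
    canned = ["Transformers", "predict", "one", "token", "at", "a", "time", ".", END]
    return canned[len(tokens)] if len(tokens) < len(canned) else END

def generate(max_tokens=50):
    tokens = []
    while len(tokens) < max_tokens:
        tok = next_token(tokens)
        if tok == END:   # the model "decides" that stopping is the most likely continuation
            break
        tokens.append(tok)
    return tokens

print(" ".join(generate()))
```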

throw310822

I have a very dumb question, I'll just throw it here: I understand word embeddings and tokenisation, and the value of each; but how can the two work together? Are embeddings calculated for tokens, and in that case, how useful are they, given that each token is just a fragment of a word, often with little or no semantic meaning?

lhnz

I've heard that nowadays subword/token embeddings are learned during the training phase, and that they are useful for reconstructing the embeddings of words that contain them, and in fact allow the model to handle typos like "aple" (instead of "apple").

coppsilgold

The way transformers operate is by transforming the embedding space through each layer. You could say that all the "understanding" is happening in that high dimensional space - that of a single token, but multiplied by the number of tokens. Seeding the embedding space with some learned value for each token is helpful. Think of it as just a vector database: token -> vector.

Decoder-only architectures (such as GPT) mask the token embedding interaction matrix (attention) such that each token embedding and all subsequent transformations only have access to preceding token embeddings (and transforms). This means that on output, only the last transformed token embedding has the full information of the entire context - and only it is capable of making predictions for the next token.

This is done so that during training, you can simultaneously make 1000s (context size) of predictions - every final token embedding transform is predicting the next token. The alternative (Encoder architecture, where there is no masking and the first token can interact with the final token) would result in massively inefficient training for predicting the next token as each full context can only make a single prediction.
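The "many predictions from one context" point can be seen by pairing each prefix with the token that follows it. This sketch only shows the bookkeeping, with a made-up sentence; in a decoder-only model, the causal mask is what lets all of these prefixes be scored in a single forward pass.

```python
tokens = ["The", "cat", "sat", "on", "the", "mat"]

# Inputs are every token but the last; targets are the same sequence shifted left by one.
inputs, targets = tokens[:-1], tokens[1:]

# Position i sees tokens[0..i] and must predict targets[i]: one prediction per position.
training_pairs = [(tokens[: i + 1], targets[i]) for i in range(len(inputs))]

for prefix, target in training_pairs:
    print(prefix, "->", target)
```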

tomhamer

Disclaimer - someone from Marqo here.

Marqo supports E5 models: https://github.com/marqo-ai/marqo

hackernewds

Why are the words columns and the properties rows? Seems counterintuitive.

parpfish

just tilt your head 90 degrees and it'll be fine.

this is rows/columns from a math/matrix/tensor perspective where they are the arbitrary first and second dimensions of a data-containing object.

it's not rows/columns from a database perspective where you expect columns to define a static schema and rows to be individual records.

noman-land

Thank you for this.

throwawaymaths

The Yannic Kilcher review is quite good.

https://youtu.be/iDulhoQ2pro

I can't ELI5 but I can ELI-junior-dev. Tl;dw:

Transformers work by basically being a differentiable lookup/hash table. First your input is tokenized and (N) tokens (this constitutes the attention frame) are encoded both based on token identity and position in the attention frame.

Then there is an NxN matrix that is applied to your attention frame "performing the lookup query" over all other tokens in the attention frame, so every token gets a "contextual semantic understanding" that takes in both all the other stuff in the attention frame and its relative position.

Gpt is impressive because the N is really huge and it has many layers. A big N means you can potentially access information farther away. Each layer gives more opportunities to summarize and integrate long range information in a fractal process.

Two key takeaways:

- differentiable hash tables

- encoding relative position using periodic functions

NB: the attention frame tokens are actually K-vectors (so the frame is a KxN matrix) and the query matrix is an NxNxK tensor IIRC but it's easier to describe it this way
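The "periodic functions" takeaway refers to the sinusoidal position encoding in the original paper: PE(pos, 2i) = sin(pos / 10000^(2i/d)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/d)). A minimal sketch, with d = 8 chosen only for readability (some later models, such as GPT-2, learn their position vectors instead):

```python
import math

def positional_encoding(pos, d_model):
    """Sinusoidal encoding: PE[pos, 2i] = sin(pos / 10000**(2i/d_model)),
    PE[pos, 2i+1] = cos(same angle). d_model must be even."""
    pe = []
    for i in range(0, d_model, 2):          # i plays the role of 2i in the formula
        angle = pos / (10000 ** (i / d_model))
        pe.append(math.sin(angle))
        pe.append(math.cos(angle))
    return pe

# Each position gets a distinct pattern; low dimensions oscillate fast, high ones slowly.
for pos in range(3):
    print(pos, [round(v, 3) for v in positional_encoding(pos, 8)])
```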

dpcx

I appreciate the explanation, but I don't know which junior dev would understand most of this. I may be just a web developer, but I couldn't understand most of it. I'd still have to read for 30 minutes to grok it all.

throwawaymaths

Yeah sorry, it still requires math and probably some exposure to ML basics.

shadowgovt

I think one hole in the simplified description is that "differentiable" isn't an adjective that applies to hash tables.

Differentiable relative to what? What is (x) in the d(hashtable)/d(x) equation?

sva_

One thing that might be worth pointing out is that the transformer architecture owes a great deal of its success to the fact that it can be implemented in a way that it can be massively parallelized in a very efficient manner.

throwawaymaths

Compared to rnns... maybe? The big nxn is really a killer.

I don't know how to judge parallelizability of different DNN models, you're comparing apples to oranges

Salgat

When you train a transformer, you're training what the next expected token is. You can train all positions of the sequence each in parallel rather than having to sequentially build up the memory state as you generate the sequence with an LSTM. Mind you the inference portion of a transformer is still sequentially bottlenecked since you don't know what the output sequence is supposed to be.

kenjackson

What does it mean for a lookup/hash table to be differentiable?

tomp

I'm not a ML expert but I know a bit about math.

It's "differentiable" in the same way that e.g. the "jump function" (Heaviside step function) is differentiable (not as a function from real numbers to real numbers, but as a distribution). Its derivative is the "point impulse function" (Dirac delta function), which, again, is a distribution, not a real function.

Distributions are nicely defined in math, but can't really be operated with numerically (at least not in the same way as real/float functions), but you can approximate them using continuous functions. So instead of having a function jump from 0 to 1, you "spread" the jump and implement it as a continuous transition from e.g. `0-epsilon` to `0+epsilon` for some tiny epsilon. Then you can differentiate it as usual, even numerically.
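A small numeric illustration of the "spread the jump" idea, assuming a logistic sigmoid as the smooth stand-in (one common choice; the comment above doesn't name a specific function). `epsilon` sets the width of the transition, and derivatives are estimated with a central difference.

```python
import math

def step(x):
    # Heaviside step: flat everywhere except a jump at 0, so its ordinary derivative is 0 or undefined.
    return 1.0 if x >= 0 else 0.0

def smooth_step(x, epsilon=0.01):
    # Logistic sigmoid squeezed into a transition of width about epsilon around 0.
    return 1.0 / (1.0 + math.exp(-x / epsilon))

def numeric_derivative(f, x, h=1e-6):
    return (f(x + h) - f(x - h)) / (2 * h)

print(numeric_derivative(step, 0.5))         # 0.0: no useful gradient away from the jump
print(numeric_derivative(smooth_step, 0.0))  # about 25, i.e. 1 / (4 * epsilon)
```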

Similarly, hash table lookup is a discontinuous function: the result of `hash.get(lookup)` is just `value` (or `null`). To make it continuous, you "spread" the value, so that nearby keys (for some definition of "nearby") will return nearby values.

One way to do this, is to use the scalar product between `lookup` and all keys in the hashtable (normalized, the scalar product is close to 1 if the arguments are "nearby"), and use the result as weights to multiply with all values in the hashtable. That's what the transformer does.
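A minimal sketch of that soft lookup with toy 2-D keys and values. It turns the dot products into weights with softmax, as transformer attention does, rather than using the raw normalized scalar products; the `temperature` parameter is an illustrative addition showing that a very sharp lookup behaves like an ordinary exact-match table.

```python
import math

def soft_lookup(query, keys, values, temperature=1.0):
    """Differentiable 'hash table': every value contributes, weighted by key similarity."""
    scores = [sum(q * k for q, k in zip(query, key)) / temperature for key in keys]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]   # subtract the max for numerical stability
    total = sum(exps)
    weights = [e / total for e in exps]        # softmax: weights sum to 1
    dim = len(values[0])
    blended = [sum(w * v[i] for w, v in zip(weights, values)) for i in range(dim)]
    return blended, weights

keys = [[1.0, 0.0], [0.0, 1.0]]
values = [[10.0, 0.0], [0.0, 10.0]]

blended, w = soft_lookup([1.0, 0.0], keys, values)
print(w)        # most weight on the first key, some on the second
sharp, w_sharp = soft_lookup([1.0, 0.0], keys, values, temperature=0.01)
print(sharp)    # essentially [10.0, 0.0], like an exact hash-table hit
```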

hoosieree

Thanks for this explanation. I couldn't wrap my mind around the "differentiable hash table" analogy, but "distribution of keys" -> "distribution of values" starts to click.

I'm not an ML expert either but I have taken graduate level courses and published papers with "machine learning" in the title, so I feel like I should be able to understand these things better. The field just moves so fast. It's a lot of work to keep up. Easy-to-digest explanations like this are underrated.

kylewatson

thank you. This made it click.

taylorius

Differentiable in this context, means that it can be learned by gradient descent, which uses derivatives to adjust a neural network model's parameters to minimise some error measure. As for how that applies to a hash function, I think the lookup gives some sort of weighting for each possible output, for each input, with the largest weightings corresponding to a "match". But tbh I'm not certain on that last part...

throwawaymaths

> As for how that applies to a hash function, I think the lookup gives some sort of weighting for each possible output, for each input, with the largest weightings corresponding to a "match"

Note that in general one "matches" more than one thing in which case (IIRC) the result will be a weighted linear combination of the query results.

I do think that this collapses to a datastructure equivalent in function to the "normal hash table" (but not in performance, lol) in the degenerate case where every input yields strictly a single result with zero activation in all other results, so it's not invalid to call it a differentiable lookup table.

throwawaymaths

You can take a (calculus) derivative of the mathematical construct that represents the lookup table. And yeah, it is very not obvious how a "lookup/hash table" could be differentiable, based on how it's implemented say, in a leetcode exercise. That's part of the genius of the transformer.

meowkit

Going to go out on a limb and say they are probably referring to the gradient calculus required for updating the model.

https://en.wikipedia.org/wiki/Differentiable_programming

See automatic differentiation.

throwawaymaths

Correct, but note that if you subject a standard hash table algo to AD it won't magically become a transformer. (Hashes in the "normal construction" are discrete functions and thus aren't really continuous or differentiable, neither are lookup tables)

coppsilgold

What actually happens is that each vector grabs a fraction of each other vector and adds it to itself (simplification omitting some transforms along the way). Equating this with the concept of a hash function does not make sense in my opinion but some authors do it anyway.

It's differentiable because how much of a fraction to grab is the result of a simple dot product followed by a softmax.

visarga

It's not really doing hashing, which is random, but instead it makes meaningful connections, like connecting an adjective to its noun, or a pronoun to a previous reference of the name. That is why it is called "Attention is all you need" and not "Hashing is all you need". Attention includes some meaning, it is attentive to some aspect or another.

throwawaymaths

Hashing does not have to be random, it just has to compress the domain. Modulo number is a perfectly valid hash for some use cases.

Edit: just looked it up, doesn't even have to compress the domain, identity hash is a thing.

theGnuMe

The lookup/hashtable can be viewed as a matrix which is accessed by multiplying it with a vector.
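A tiny illustration of that view, with arbitrary numbers: store the table's values as the columns of a matrix, and "look up" entry j by multiplying with a one-hot vector (1 at position j, 0 elsewhere). A blended vector instead mixes several entries, which is what makes the lookup differentiable.

```python
def matvec(matrix, vector):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(m * v for m, v in zip(row, vector)) for row in matrix]

# Column j of the table holds the value stored under key j.
table = [
    [10, 20, 30],
    [11, 21, 31],
]

one_hot = [0, 1, 0]               # exact "key" 1
print(matvec(table, one_hot))     # [20, 21]: exactly column 1

blend = [0.0, 0.5, 0.5]           # a soft key mixes columns 1 and 2
print(matvec(table, blend))       # [25.0, 26.0]
```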

weinzierl

I wanted to ask the same and especially I've always been wondering: How is the meaning of aforementioned 'differentiable' related to the same term in math?

dcre

Not an expert (so this could well be slightly off), but here is a physics analogy:

You're rolling a ball down a plank of wood and you want it to roll off the plank and continue to a certain spot on the ground. You know that if the ball rolls past the target, you have to angle the ramp lower. If it fails to reach the target, you angle the ramp higher. If it goes past by a lot, you make a big change to the ramp angle. If it goes past by a little, you make a small change to the ramp angle. In this way your error tells you something about the adjustment you need to make to the system to produce the desired output.

Think of a function from the ramp angle (network weights) to how far the ball lands from the target (the error). The derivative of this function tells you what kind of change in angle results in what kind of change in the error, and you can use that to update the ramp until you hit the target. To say that the model is differentiable is to say that it's possible to infer from the error what kind of changes you need to make to the weights.
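The ramp can be simulated in a few lines of Python. The "physics" is invented (landing distance = 10 sin(2 × angle), a projectile-style curve) and the derivative is estimated numerically; each step changes the angle in proportion to how steeply the error changes, so large errors produce large adjustments.

```python
import math

TARGET = 5.0  # where we want the ball to land (arbitrary units)

def landing_distance(angle):
    # Invented toy physics: distance grows with angle up to 45 degrees.
    return 10.0 * math.sin(2 * angle)

def error(angle):
    return (landing_distance(angle) - TARGET) ** 2

def derivative(f, x, h=1e-6):
    return (f(x + h) - f(x - h)) / (2 * h)

angle = 0.1           # start with a shallow ramp (radians)
learning_rate = 0.001
for _ in range(1000):
    # Steep error slope -> big adjustment; gentle slope -> small adjustment.
    angle -= learning_rate * derivative(error, angle)

print(round(math.degrees(angle), 2))  # 15.0 (sin(30 degrees) = 0.5, so distance = 5)
```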

kccqzy

Same thing. You need everything to be differentiable in order to run gradient descent. To first approximation, training a neural network is just gradient descent.

visarga

The neural net is just a math function, continuous even, fully differentiable in all input points. In order to "learn" anything we compute gradients towards the function parameters. They get "nudged" slightly towards a better response, and we do this billions of times. It's like carving a raw stone block into a complex scene. If you put your data into the system it flows towards the desired output because the right path has been engraved during training.

<rant> This explains a bit about how neural nets work, but getting from this to ChatGPT is another whole leap. You'd have to assign some of the merits of the AI to the training data itself; it's not just the algorithm for learning, but what is being learned that matters. The neural net is the same, but using 1T tokens of text is making it smart. What's so magic about this data that it can turn a random init into a language-programmable system? And the same language information makes a baby into a modern human, instead of just another animal. </>

rekttrader

I came here to post this video. It’s a great primer on the topic and it gives you ideas to prompt gpt and have it output more.

It’s how I got an understanding of beam search, a technique employed in some of the response building.

jimkleiber

I had to look up "tl;dw" and realized it meant "too long; didn't watch" and not my first AI-laden instinct of "too long, didn't write" :-D

ogoparootbbo2

What is a differentiable hash table? I understand differentiation, but I don't understand a differentiable hash table... does that mean that for every tiny gradient of a key, a corresponding gradient in value can be expected?

legalizemoney

Having read the paper myself, I'm impressed with the quality of your explanation. Well done!

zorr

How does N relate to the number of parameters that is frequently mentioned?

throwawaymaths

In my screed, N is the attention width (how many tokens it looks at at a time). The number of parameters is O(KxNxNxL), where K is the vector size of your tokens and L is the number of layers. There are other parameters floating around, like in the encoder and decoder matrices, but the NxN matrix dominates.

8thcross

This is an awesome explanation. You guys are the real heroes

PeterisP

ELI5 is tricky as details have to be sacrificed, but I'll try.

An attention mechanism is when you want a neural network to learn the function of how much attention to allocate to each item in a sequence, to learn which items should be looked at.

Transformers are built on self-attention, where you ask the neural network to 'transform' each element by looking at its potential combination with every other element and using this (learnable, trainable) attention function to decide which combination(s) to apply.

And it turns out that this very general mechanism, although compute-intensive (it considers everything linking with everything, so complexity is quadratic in sequence length) and data-intensive (it has lots and lots of parameters, so it needs huge amounts of data to be useful), can actually represent many of the things we care about in a manner which can be trained with the deep learning algorithms we already had.

And, really, that's the two big things ML needs, a model structure where there exists some configuration of parameters which can actually represent the thing you want to calculate, and that this configuration can actually be determined from training data reasonably.

tworats

The Illustrated Transformer ( https://jalammar.github.io/illustrated-transformer/ ) and Visualizing attention ( https://towardsdatascience.com/deconstructing-bert-part-2-vi... ) are both really good resources. For a more ELI5 approach, this non-technical explainer ( https://www.parand.com/a-non-technical-explanation-of-chatgp... ) covers it at a high level.

dominickramer

Suppose someone asked you to complete the sentence:

“After I woke up and made breakfast, I drank a glass of …”

In America one might say the most likely next words are “orange juice”, or “apple juice” but not “sports car” which has nothing to do with the sentence.

Ultimately this is what language models do, given a sequence of data (in this case words) predict the most likely next word(s).

For attention, when you read the sentence, which words stood out as more important? Probably “woke up”, “breakfast”, and “glass”, while the words “after”, “I”, and “made” were less important to completing the sentence.

That is, you paid more attention to the important words to understand how to complete the sentence.

The “attention mechanism” in language models is a way to let the models learn which words are important in sentences and pay more attention to them too when completing sentences, just like a person would do as in the example above.

Further, it turns out this attention mechanism lets the models do lots of interesting things even without other fancy model techniques. That is, “attention is all you need”.

nafey

When one says "attention is all you need" the implication is that some believe that you need something more than just attention. What is that something which has been demonstrated as unneeded? Is it a theory of how language works?

trenchgun

Recurrence. Before transformers, attention was used in recurrent neural networks. "Attention is all you need" showed that you can drop the recurrence and just use attention, and the outcome is that you get a very nicely parallelizable architecture, allowing more efficient training.

Sammi

Finally someone cut to the chase. Thank you!

FranklinMaillot

Those Computerphile videos[0] by Rob Miles helped me understand transformers. He specifically references the "Attention is all you need" paper.

And for a deeper dive, Andrej Karpathy has this hands-on video[1] where he builds a transformer from scratch. You can check out his other videos on NLP as well; they are all excellent.

[0] https://youtu.be/rURRYI66E54, https://youtu.be/89A4jGvaaKk

[1] https://youtu.be/kCc8FmEb1nY

WithinReason

Well here is my (a bit cynical) take on it.

In the beginning, there was the matrix multiply. A simple neural network is a chain of matrix multiplies. Let's say you have your data A1 and weights W1 in a matrix. You produce A2 as A1xW1. Then you produce A3 as A2xW2, and so on. There are other operations in there like non-linearities (so that you can actually learn something interesting) and fancy batch norms, but let's forget about those for now.

The problem with this is, it's not very expressive. Let's say your A1 matrix has just 2 values, and you want the output to be their product. Can you learn a weight matrix that performs multiplication of these inputs? No, you can't. Multiplication must be simulated by piecing together piecewise linear functions. To perform multiplication, the weight matrix W would also need to be produced by the network.

Transformers do basically that. In the product AxW you replace A with (AxW1), W with (AxW2), and multiply those together: (AxW1)x(AxW2)^T. And then do it once more for good measure: (AxW1)x(AxW2)^Tx(AxW3). Boom, Nobel prize. Now your network can multiply, not just add.

OK, it's actually a bit more complicated: there is, for example, a softmax in the middle to perform normalisation, which in general helps during numerical optimisation: softmax((AxW1)x(AxW2)^T)x(AxW3). There are then fancy explanations that try to retrospectively justify this as a "differentiable lookup table" or somesuch nonsense, calling the 3 parts "key", "query" and "value", which help make your paper more popular. But the basic idea is not so complicated.

A Transformer then uses this operation as a building block (running many of them in parallel and in sequence) to build giant networks that can do really cool things. Maybe you can teach networks to divide next, and then you get the next Nobel prize.
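The "now your network can multiply" claim can be checked on the two-input example. With hand-picked (not learned) weights, W1 selects the first input and W2 the second, so (AxW1)x(AxW2)^T comes out as the product of the two inputs, which no single fixed matrix applied to A can do for all inputs. The softmax and W3 are left out so the arithmetic stays visible.

```python
def matmul(a, b):
    """Plain matrix product of two lists-of-rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def transpose(m):
    return [list(r) for r in zip(*m)]

W1 = [[1.0], [0.0]]  # hand-picked: selects the first input
W2 = [[0.0], [1.0]]  # hand-picked: selects the second input

def product_layer(a, b):
    A = [[a, b]]                                   # one row holding the two inputs
    left, right = matmul(A, W1), matmul(A, W2)     # (AxW1) and (AxW2)
    return matmul(left, transpose(right))[0][0]    # (AxW1)x(AxW2)^T

print(product_layer(3.0, 4.0))   # 12.0
print(product_layer(-2.0, 5.0))  # -10.0
```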

mcdougal

Are transformers

- a hack devised/stumbled upon by AI workers or

- a theoretical concept concocted by a mathematician who has been thinking about what NNs do or,

- a set of techniques pipe-lined together by clever programmers who work with NNs? Or...

- something else?

I mean, if transformers really do something rational, then there should be a straightforward rational mathematical statement of the problem, a clear and clean expression of what they do!

What I see instead is a lot of complex cumbersome description and terminological noise: no clear problem statement, lots of steps, lots of moving parts, and downhill from there.

Now I'd be the first to admit that, if I believed I understood and could reproduce intelligence or language [and is intelligence merely language? An argument to that effect can be made - see Helen Keller], then if required I'd be prone to provide dense, noisy and incorrect explanations aplenty to potential competitors and even to honest inquisitive people. I would do that b/c revelation of the truth would destroy my competitive advantage. IOW I see every reason for ChatGPT et al developers and corporations to guide outsiders astray at this time.

Developing something like ChatGPT is like running an exposed Manhattan Project - everyone wants Da Bomb and you don't want them to have it - instead you want to lead them completely astray. Seems to be succeeding: certainly as far as I'm concerned.

Here's a simple test: has anyone within reach of my words made a version of these systems from scratch that does anything like the fullblown ChatGPT does and that (s)he will reveal?

Current times would indicate that we won't learn how it all works until someone leaks it or someone else figures it out (like happened to Heisenberg and Schrodinger). My bet is on the latter. And that's the guy/gal who should get the Fields medal or Nobel Prize. [I'd bet a mathematician will do it].

Meanwhile as we twiddle our matrices much effort must be afoot to infiltrate the ChatGPT working groups and get the goods and also to keep the current ChatGPT worker bees from flying off to other hives and revealing secrets. This may be one of the few times when tech workers' jobs seriously shorten their lives.

- Wandering in the desert...

seydor

Are there any papers using more than 3 linearly transformed vectors?

devit

It works like this:

First, convert the input text to a sequence of token numbers (2048 tokens with 50257 possible token values in GPT-3) by using a dictionary. For each token, create a vector with 1 at the token index and 0 elsewhere, transform it with a learned "embedding" matrix (50257x12288 in GPT-3) and add a position vector to it. In the original paper that position vector is built from sine and cosine functions with several different periodicities; GPT-3 learns its position vectors instead.
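A minimal sketch of this first step in plain Python, with toy sizes instead of GPT-3's (the vocabulary size, vector size, random matrix and token ids below are made up for illustration). It shows that multiplying a one-hot vector by the embedding matrix simply picks out one row, and it uses the sine/cosine position vectors from the original paper.

```python
import math
import random

def one_hot(index, size):
    """Vector with 1 at `index` and 0 elsewhere."""
    v = [0.0] * size
    v[index] = 1.0
    return v

def vec_mat(v, m):
    """Row vector v (length rows) times matrix m (rows x cols)."""
    cols = len(m[0])
    return [sum(v[i] * m[i][j] for i in range(len(v))) for j in range(cols)]

def sinusoidal_position(pos, d_model):
    """Sine/cosine position vector as in "Attention Is All You Need"."""
    pe = []
    for i in range(d_model):
        # dimensions 2k and 2k+1 share a frequency: even -> sin, odd -> cos
        angle = pos / (10000 ** ((i - i % 2) / d_model))
        pe.append(math.sin(angle) if i % 2 == 0 else math.cos(angle))
    return pe

# Toy sizes (hypothetical); GPT-3 uses a vocabulary of 50257 and d_model 12288.
VOCAB, D_MODEL = 10, 4
random.seed(0)
# Random numbers stand in for the learned embedding matrix.
embedding = [[random.gauss(0, 1) for _ in range(D_MODEL)] for _ in range(VOCAB)]

token_ids = [3, 7, 1]  # made-up output of a tokenizer
inputs = []
for pos, tok in enumerate(token_ids):
    emb = vec_mat(one_hot(tok, VOCAB), embedding)  # identical to embedding[tok]
    inputs.append([e + p for e, p in zip(emb, sinusoidal_position(pos, D_MODEL))])
```

Real implementations skip the one-hot multiplication and index the row directly, since the result is identical.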

Then, for each layer and each attention head (96 layers and 96 heads per layer in GPT-3), transform the input vector by query, key and value matrices (12288x128 in GPT-3) to obtain a query, key and value vector for each token. Then for each token, compute the dot product of its query vector with the key vectors of all previous tokens (including itself), scale by 1/sqrt of the key vector dimension and normalize the results so they sum to 1 by using softmax (i.e. applying e^x and dividing by the sum), giving the attention coefficients; then, compute the attention head output by summing the value vectors of those tokens weighted by the attention coefficients. Now, for each token, glue together the outputs of all attention heads in the layer (each head has its own learned key/query/value matrices), multiply by a learned output matrix, add the input and normalize (normalizing means that the vector values are shifted and scaled so they have mean 0 and variance 1, followed by a learned scale and bias).
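A minimal sketch of one attention layer as described above, in plain Python with toy sizes (d_model = 4 and two heads of size 2 instead of GPT-3's 12288 and 96 heads of 128; random matrices stand in for the learned ones). The glued head outputs are multiplied by a learned output matrix (`wo` here) before the input is added. Normalization comes after the residual add, as in the original paper, and the learned scale and bias of the normalization are omitted.

```python
import math
import random

def softmax(xs):
    """e^x for each entry, divided by the sum (shifted by the max for stability)."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def matvec(v, w):
    """Row vector v (length d_in) times matrix w (d_in x d_out)."""
    return [sum(v[i] * w[i][j] for i in range(len(v))) for j in range(len(w[0]))]

def attention_head(xs, wq, wk, wv):
    """One causal head: token t attends to tokens 0..t (itself included)."""
    qs = [matvec(x, wq) for x in xs]
    ks = [matvec(x, wk) for x in xs]
    vs = [matvec(x, wv) for x in xs]
    d_head = len(qs[0])
    outputs = []
    for t, q in enumerate(qs):
        scores = [sum(a * b for a, b in zip(q, ks[s])) / math.sqrt(d_head)
                  for s in range(t + 1)]
        weights = softmax(scores)  # attention coefficients, sum to 1
        outputs.append([sum(w * vs[s][j] for s, w in enumerate(weights))
                        for j in range(d_head)])
    return outputs

def layer_norm(v, eps=1e-5):
    """Shift and scale v to mean 0, variance 1 (learned scale/bias omitted)."""
    mean = sum(v) / len(v)
    var = sum((x - mean) ** 2 for x in v) / len(v)
    return [(x - mean) / math.sqrt(var + eps) for x in v]

def attention_layer(xs, heads, wo):
    """Run every head, glue outputs per token, project, add input, normalize."""
    per_head = [attention_head(xs, *h) for h in heads]
    out = []
    for t, x in enumerate(xs):
        glued = [val for head_out in per_head for val in head_out[t]]
        projected = matvec(glued, wo)
        out.append(layer_norm([a + b for a, b in zip(x, projected)]))
    return out

def rand_matrix(rows, cols):
    return [[random.gauss(0, 0.5) for _ in range(cols)] for _ in range(rows)]

# Toy sizes (hypothetical): 3 tokens, d_model = 4, 2 heads of size 2.
random.seed(1)
D_MODEL, N_HEADS, D_HEAD = 4, 2, 2
heads = [tuple(rand_matrix(D_MODEL, D_HEAD) for _ in range(3))  # (wq, wk, wv)
         for _ in range(N_HEADS)]
wo = rand_matrix(N_HEADS * D_HEAD, D_MODEL)
xs = [[random.gauss(0, 1) for _ in range(D_MODEL)] for _ in range(3)]
ys = attention_layer(xs, heads, wo)
```

Note that the first token can only attend to itself, so its head output is exactly its own value vector.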

Next, for the feedforward layer, apply a learned matrix, add a learned vector and apply a nonlinearity (GPT-3 uses GELU, a smooth variant of ReLU, which is f(x) = x for positive x and f(x) = 0 for negative x), then apply a second learned matrix and add a second learned vector (12288x49152 and 49152x12288 matrices in GPT-3; these account for about two thirds of the parameters in GPT-3), then add the input from before the feedforward layer and normalize.
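A small sketch of the feedforward block for one token, using ReLU for simplicity (GPT-3 itself uses GELU), followed by the arithmetic behind the share of GPT-3's weights that these matrices account for. The counts use the dimensions quoted in the comment and ignore biases, normalization parameters and position vectors.

```python
def relu(x):
    """ReLU: x for positive x, 0 otherwise."""
    return max(0.0, x)

def feed_forward(x, w1, b1, w2, b2):
    """Linear -> ReLU -> linear, applied to one token vector."""
    hidden = [relu(sum(x[i] * w1[i][j] for i in range(len(x))) + b1[j])
              for j in range(len(b1))]
    return [sum(hidden[i] * w2[i][j] for i in range(len(hidden))) + b2[j]
            for j in range(len(b2))]

# Per-layer weight counts with GPT-3's sizes.
d_model, d_ff, n_layers, vocab = 12288, 49152, 96, 50257
attn_params = 4 * d_model * d_model   # query, key, value and output matrices
ffn_params = 2 * d_model * d_ff       # the two feedforward matrices
ffn_share = ffn_params / (attn_params + ffn_params)  # exactly 2/3 here

# All layers plus the token embedding matrix: about 174.6 billion,
# close to the commonly quoted 175B figure.
total_params = n_layers * (attn_params + ffn_params) + vocab * d_model
print(f"feedforward share per layer: {ffn_share:.1%}, total: {total_params:,}")
```

Because d_ff is 4 x d_model, the feedforward matrices hold 8 x d_model^2 weights per layer against 4 x d_model^2 for attention, which is where the two-thirds share comes from.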

Repeat the process for each layer, each with its own matrices, passing the output of the previous layer as input. Finally, multiply by the transpose of the initial embedding matrix and use softmax to get probabilities for the next token at each position. For training, adjust the network so that these probabilities put as much weight as possible on the actual next token in the text. For inference, sample the next token from the probability distribution (typically restricted to the top K tokens, or to tokens above a probability cutoff), append it to the input, and repeat the whole thing until an end-of-text token is generated.
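A minimal sketch of that last step for a single position, in plain Python with a made-up 5-token vocabulary and 3-dimensional vectors. Scoring each vocabulary entry by a dot product with its embedding row is the same as multiplying by the transpose of the embedding matrix; top-K sampling keeps only the K most likely tokens before drawing one. The helper names are illustrative, not from any library.

```python
import math
import random

def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def next_token_probs(hidden, embedding):
    """Dot product of the final vector with every embedding row
    (i.e. multiply by the transposed embedding matrix), then softmax."""
    logits = [sum(h * e for h, e in zip(hidden, row)) for row in embedding]
    return softmax(logits)

def sample_top_k(probs, k, rng):
    """Keep the k most likely tokens, renormalize, and sample one of them."""
    top = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    total = sum(probs[i] for i in top)
    return rng.choices(top, weights=[probs[i] / total for i in top])[0]

# Made-up embedding matrix (5 tokens x 3 dims) and final-layer output vector.
embedding = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0],
             [1.0, 1.0, 0.0], [-1.0, 0.0, 0.0]]
hidden = [2.0, 0.5, -1.0]
probs = next_token_probs(hidden, embedding)
token = sample_top_k(probs, k=2, rng=random.Random(0))  # always token 3 or 0
```

In a real generation loop the sampled token is appended to the input and the whole network runs again, stopping at the end-of-text token.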

probably_wrong

I'll throw my hat in the ring.

A transformer is a type of neural network that, like many networks before, is composed of two parts: the "encoder" that receives a text and builds an internal representation of what the text "means"[1], and the "decoder" that uses the internal representation built by the encoder to generate an output text. Let's say you want to translate the sentence "The train is arriving" to Spanish.

Both the encoder and decoder are built like Lego, with identical layers stacked on top of each other. The lowest level of the encoder looks at the input text and identifies the role of individual words and how they interact with each other. This is passed to the layer above, which does the same but at a higher level. In our example it would be as if the first layer identified that "train" and "arrive" are important, then the second one identifies that "the train" and "is arriving" are core concepts, the third one links both concepts together, and so on.

The encoder's final representation (one vector per input word) is then passed to every layer of the decoder, which uses it to generate a single word, in this case "El". This word is then fed back to the decoder, which now needs to generate an appropriate continuation for "El", which in this case would be "tren". You repeat this procedure over and over until the transformer says "I'm done", hopefully having generated "El tren está llegando" in the process.
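A minimal sketch of that generate-and-feed-back loop in Python. The decoder here is a hypothetical stand-in (`toy_decoder`, a hard-coded lookup table) so that only the loop itself is on display; a trained decoder would compute each next word from the encoder's representation of the source sentence plus the words generated so far.

```python
END = "<end>"

# Hypothetical stand-in for a trained decoder: maps the words generated so far
# to the next word. It ignores the source sentence entirely.
TOY_NEXT_WORD = {
    (): "El",
    ("El",): "tren",
    ("El", "tren"): "está",
    ("El", "tren", "está"): "llegando",
    ("El", "tren", "está", "llegando"): END,
}

def toy_decoder(encoded_source, generated):
    return TOY_NEXT_WORD[tuple(generated)]

def translate(encoded_source, decoder, max_len=50):
    """Generate one word at a time, feeding each output back in."""
    generated = []
    while len(generated) < max_len:
        word = decoder(encoded_source, generated)
        if word == END:  # the model says "I'm done"
            break
        generated.append(word)  # becomes part of the input on the next step
    return " ".join(generated)

print(translate("The train is arriving", toy_decoder))  # prints: El tren está llegando
```

The `max_len` cap is a common safeguard in case the model never produces the end marker.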

The attention mechanism already existed before transformers, typically coupled with an RNN. The key concept of the transformer was building an architecture that removed the RNN completely. The negative side is that it is a computationally inefficient architecture, as there are plenty of operations that are n^2 in the length of the input [2]. Luckily for us, a bunch of companies started releasing giant models trained on lots of data for free, researchers learned how to "fine tune" them to specific tasks using way less data than it would have taken to train from scratch, and transformers exploded in popularity.

[1] I use "mean" in quotes here because the transformer can only learn from word co-occurrences. It knows that "grass" and "green" go well together, but it doesn't have the data to properly say why. The paper "Climbing towards NLU" is a nice read if you care about the topic, but be aware that some people disagree with this point of view.

[2] The transformer is less efficient than an LSTM in the total number of operations but, simultaneously, it is easier to parallelize. If you are Google this is the kind of problem you can easily solve by throwing a data center or two at the problem.

jimbokun

> The negative side is that it is a computationally inefficient architecture as there are plenty of n^2 operations on the length of the input

Is this the reason for the limited token windows?

probably_wrong

Yes, kinda. The transformer doesn't have a mechanism for dynamically adjusting its input size, so you need to strike a balance between the window being big enough for practical purposes but also small enough that you can still train the network.
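A quick back-of-the-envelope sketch of why the window is expensive: each attention head compares every token with every token it may attend to, so the number of scores grows with the square of the window length. The memory figure assumes a naive implementation that stores one head's full n x n score matrix as 16-bit floats; optimized implementations avoid storing it, but the amount of computation still grows quadratically.

```python
def attention_score_entries(n_tokens, causal=True):
    """Number of query-key dot products one attention head computes."""
    if causal:
        # token t compares against tokens 0..t: 1 + 2 + ... + n
        return n_tokens * (n_tokens + 1) // 2
    return n_tokens * n_tokens

for n in (1024, 2048, 4096, 8192):
    full = attention_score_entries(n, causal=False)
    mib = full * 2 / 2**20  # 2 bytes per 16-bit score
    print(f"{n:>5} tokens: {full:>11,} scores, {mib:,.0f} MiB per head")
```

Doubling the window quadruples the work per head, and that cost is paid in every head of every layer.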

Previous networks with RNNs could in theory receive inputs of arbitrary size, but in practice their performance decreased as the input got longer because they "forgot" the earlier input as they went on. The paper "Neural Machine Translation by Jointly Learning to Align and Translate" solved the forgetting problem by, you guessed it, adding attention to the model.

Eventually people realized that attention was all you needed (ha!), removed the RNN, and here we are.

jayalammar

I'm the author of https://jalammar.github.io/illustrated-transformer/ and have spent years since introducing people to Transformers and thinking of how best to communicate those concepts. I've found that different people need different kinds of introductions, and the thread here includes some often cited resources including:

https://peterbloem.nl/blog/transformers

https://e2eml.school/transformers.html

I would also add Luis Serrano's article here: https://txt.cohere.com/what-are-transformer-models/ (HN discussion: https://news.ycombinator.com/item?id=35576918).

Looking back at The Illustrated Transformer, when I introduce people to the topic now, I find I can hide some complexity by omitting the encoder-decoder architecture and focusing only on one. Decoders are great because now a lot of people come to Transformers having heard of GPT models (which are decoder only). So for me, my canonical intro to Transformers now only touches on a decoder model. You can see this narrative here: https://www.youtube.com/watch?v=MQnJZuBGmSQ

kartayyar

- You can develop a very deep understanding of a sequence by observing how its elements interact with each other over many sequences.

- This understanding can be encapsulated in a "compressed", low-dimensional vector representation of a sequence.

- You can use this understanding for many different downstream tasks, especially predicting the next item in a sequence.

- This approach scales really well with lots of GPUs and data and is super applicable to generating text.

vikp

Transformers are about converting some input data (usually text) to numeric representations, then modifying those representations through several layers to generate a target representation.

In LLMs, this means going from prompt to answer. I'll cover inference only, not training.

I can't quite ELI5, but the process is roughly:

  - Write a prompt
  - Convert each token in the prompt (roughly a word) into numbers. So "the" might map to the number 45.
  - Get a vector representation of each word - go from 45 to [.1, -1, -2, ...]. These vector representations are how a transformer understands words.
  - Combine vectors into a matrix, so the transformer can "see" the whole prompt at once.
  - Repeat the following several times (once for each layer):
      - Multiply the vectors by the other vectors. This is attention - it's the magic of transformers, that enables combining information from multiple tokens together. This generates a new matrix.
      - Feed the matrix into a linear regression. Basically multiply each number in each vector by another number, then add them all together. This will generate a new matrix, but with "projected" values.
      - Apply a nonlinear transformation like ReLU. This helps model more complex functions (like text input -> output!)

Note that I really oversimplified the last few steps, and the ordering.

At the end, you'll have a matrix. You then convert this back into numbers, then into text.

throwawaymaths

I don't think this description of attention is correct.

vikp

You mean "Multiply the vectors by the other vectors. This is attention - it's the magic of transformers, that enables combining information from multiple tokens together. This generates a new matrix."?

It's really oversimplified, as I mentioned. A more granular look is:

  - Project the vectors with a linear regression. In self-attention (including the decoder-only attention we usually use), we project the same vectors three times with different coefficients. We call the projections queries, keys and values. This transforms the vectors linearly.
  - Find the dot product of each query vector against the key vectors (multiply them), scaled down by the square root of the vector size
  - Mask out future vectors, so a token can't look at tokens that come after it (needed whenever several tokens are processed at once, i.e. in training and when reading the prompt)
  - At this point, you will have a matrix indicating how important each query vector considers each other vector (how important each token considers the other tokens)
  - Take the softmax, which both ensures all of the attention values for a vector sum to 1, and penalizes small attention values
  - Use the softmax values to get a weighted sum of the value vectors according to the attention calc.
  - This will turn one vector into the weighted sum of the other vectors it considers important.

The goal of this is to incorporate information from multiple tokens into a single representation.
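A minimal sketch of these steps for a single head, in plain Python. It takes already-projected queries, keys and values (made-up numbers), applies the future mask by setting those scores to minus infinity before the softmax (a common way to do it, since e^-inf is 0), and includes the 1/sqrt(d) scaling from the paper.

```python
import math

def softmax(row):
    m = max(row)
    exps = [math.exp(x - m) for x in row]  # exp(-inf) == 0.0, so masked entries vanish
    total = sum(exps)
    return [e / total for e in exps]

def masked_attention(queries, keys, values):
    """Causal scaled dot-product attention for one head."""
    n, d = len(queries), len(queries[0])
    weights = []
    for i in range(n):
        row = []
        for j in range(n):
            if j > i:
                row.append(float("-inf"))  # future token: masked out
            else:
                row.append(sum(q * k for q, k in zip(queries[i], keys[j])) / math.sqrt(d))
        weights.append(softmax(row))  # each row sums to 1
    # each output is the attention-weighted sum of the value vectors
    outputs = [[sum(weights[i][j] * values[j][c] for j in range(n))
                for c in range(len(values[0]))] for i in range(n)]
    return weights, outputs

# Toy query/key/value vectors for 3 tokens (queries reused as keys for brevity).
qs = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
vs = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
weights, outputs = masked_attention(qs, qs, vs)
for row in weights:
    print([round(w, 3) for w in row])
```

The printed weight matrix is lower-triangular: everything above the diagonal is exactly 0, which is the mask at work.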