2022-2-27: Flash, Expert Choice Routing, Effective MoE, Merging inputs and tokens
Mixture-of-Experts with Expert Choice Routing
Instead of choosing the top-k experts for each token, you choose the top-k tokens per expert. Seems to work even better. I actually started coding this independently last month (scooped!), and the subtleties are: 1) it makes your routing function super cheap, which is great, but 2) you end up summing different numbers of activation tensors for each token, which is hard to make efficient. You can embedding_bag this, but even constructing the indices is a pain.
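A minimal sketch of the routing step, assuming a softmax router and a fixed per-expert token budget (names and details are mine, not the paper's):

```python
# Minimal sketch of expert-choice routing, assuming a softmax router and a fixed
# per-expert token budget (names and details are mine, not the paper's).
import torch
import torch.nn.functional as F

def expert_choice_routing(x, router_weights, capacity):
    # x: (num_tokens, d_model); router_weights: (d_model, num_experts)
    scores = F.softmax(x @ router_weights, dim=-1)    # (num_tokens, num_experts)
    # Each expert (column) picks its `capacity` highest-scoring tokens,
    # instead of each token picking its top-k experts.
    gates, token_idx = scores.topk(capacity, dim=0)   # both (capacity, num_experts)
    # Expert e processes x[token_idx[:, e]], weighted by gates[:, e]. A given token
    # can land in 0, 1, or several experts, which is where the variable-length
    # summation headache comes from.
    return gates, token_idx
```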
Designing Effective Sparse Expert Models
Use top-2 routing, penalize the squared L2 norm of the logits in the routing function, and don't get rid of multiplicative interactions like RMSNorm and GEGLU (removing them fixes frequent training divergence, but also hurts quality). MoE models do worse downstream because they overfit; try finetuning with a smaller batch size and higher learning rate, and maybe finetuning only a subset of the params (but make sure it's not only the MoE params, because that sucks). Also observed that larger capacity factors always yield higher quality (at least for pretraining). Also, multiplying by a learned bias (instead of adding one) in the GEGLU after the first linear in the FFN helps. Bunch of negative results in the final appendix, though I didn't think any were that interesting.
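A sketch of the logit penalty as I described it; the coefficient and reduction are my own guesses, and the paper's exact formulation may differ:

```python
# Sketch of the routing-logit penalty as described above: add the mean squared norm
# of the router logits to the task loss with a small coefficient.
# (The coefficient and exact reduction are my assumptions, not the paper's.)
import torch

def router_logit_penalty(router_logits, coeff=1e-3):
    # router_logits: (num_tokens, num_experts), pre-softmax
    return coeff * router_logits.pow(2).sum(dim=-1).mean()
```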
Benchmarking the Linear Algebra Awareness of TensorFlow and PyTorch
PyTorch and TF aren’t as smart as Eigen or other libs that exploit associativity + distributivity in sequences of linear algebra ops. They especially fail at loop hoisting, common-subexpression elimination, and not computing unnecessary stuff when you only need slices of the result. Unclear how often this matters; they don’t have a benchmark of workloads or anything, just microbenchmarks highlighting failure cases. These limitations are well-known, but nice to see the “missing” features vs lower-level libs all enumerated in one place.
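A toy example of the kind of missed optimization they mean (my example, not one of their benchmarks):

```python
# Toy illustration (my example, not one of the paper's benchmarks): PyTorch evaluates
# A @ B @ v strictly left-to-right, materializing the full n x n product A @ B, even
# though right-associating only ever needs matrix-vector products.
import torch

n = 1024
A, B = torch.randn(n, n), torch.randn(n, n)
v = torch.randn(n, 1)

slow = A @ B @ v    # O(n^3) work plus an n x n temporary
fast = A @ (B @ v)  # O(n^2) work: two matrix-vector products
print((slow - fast).abs().max())  # tiny; the two differ only by float rounding
```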
Y-Tuning: An Efficient Tuning Paradigm for Large-Scale Pre-Trained Models via Label Representation Learning
Trying to finetune large language models better. I found this pretty hard to make sense of, but they seem to be feeding in learnable representations of each label as input tokens, then jointly attending to those and the input representation from a frozen model, and then just argmax-ing across the final label-token representations to get a prediction. For some reason they train it with triplet loss. Results look like an improvement in accuracy compared to alternatives, and they have the nice property that they only need to run inference through the frozen pretrained model once, in the first epoch, and can reuse the outputs afterward.
ProxSkip: Yes! Local Gradient Steps Provably Lead to Communication Acceleration! Finally!
New SGD variant for federated learning (and distributed optimization with expensive communication more generally). Basically like FedAvg, which averages weights instead of gradients, but 1) only communicates with some prob p < 1, and 2) shrinks weights towards the last global avg at each step, instead of having them “jump” straight there. Can probably derandomize this for better performance empirically, at the cost of the nice provable convergence rate. I see federated optimization as likely to eventually be relevant as super large-scale training starts hitting the same communication bottlenecks.
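A rough sketch of that high-level recipe, NOT the actual ProxSkip algorithm (which uses control variates to earn its provable rate); `grad_fn` is a stand-in for each worker's local stochastic gradient:

```python
# Rough sketch of the high-level recipe as described above: local SGD steps, a pull
# toward the last global average each step, and communication only with probability p.
# This is NOT the exact ProxSkip algorithm; grad_fn and all constants are stand-ins.
import random
import torch

def local_steps_with_skipped_communication(workers, global_avg, grad_fn,
                                           lr=0.01, p=0.1, pull=0.05, steps=100):
    # workers: list of per-worker parameter tensors; global_avg: last communicated average
    for _ in range(steps):
        for w in workers:
            w -= lr * grad_fn(w)             # local gradient step (no communication)
            w -= pull * (w - global_avg)     # shrink toward the last global average
        if random.random() < p:              # communicate only occasionally
            global_avg = torch.stack(workers).mean(dim=0)
            workers = [global_avg.clone() for _ in workers]
    return global_avg
```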
DataMUX: Data Multiplexing for Neural Networks
Extreme Mixup used as a speedup technique. Combine a bunch of inputs into one input (of the same size), and then add a module at the end that untangles them and makes a prediction for each one. Input multiplexer is just random projection + elemwise averaging, after doubling the feature space. Can demultiplex N inputs by feeding in a learned query that tells the demux which input to construct the final embedding for. I get the sense they’re either not thinking through what’s happening in the mux or not describing it well, because they talk about appending a position-specific token to each input before the mux, but since they’re applying a linear transform and then averaging all the representations, this is equivalent to just doubling the input space and adding in a fixed bias. Also, they have to pretrain it with an input reconstruction task to get the demux to work. Hard to assess speed vs quality tradeoff here because they report 10-20x speedups, but also huge 2-8% accuracy drops. But pretty interesting that they got this to work at all.
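A sketch of just the multiplexing step as described (ignoring the feature-space doubling, the index tokens, and the learned demultiplexer; names are mine):

```python
# Sketch of just the multiplexing step as described above: a fixed random projection
# per input position, then elementwise averaging. The real method also doubles the
# feature space, uses learned demultiplexing queries, and needs reconstruction
# pretraining, all of which are omitted here.
import torch

def multiplex(xs, projections):
    # xs: list of N inputs, each (seq_len, d); projections: N fixed random (d, d) matrices
    return torch.stack([x @ P for x, P in zip(xs, projections)]).mean(dim=0)

N, seq_len, d = 4, 128, 768
projections = [torch.randn(d, d) / d ** 0.5 for _ in range(N)]
xs = [torch.randn(seq_len, d) for _ in range(N)]
mixed = multiplex(xs, projections)  # a single (seq_len, d) "input" carrying N examples
```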
When, Why, and Which Pretrained GANs Are Useful?
“initializing [with] a pretrained checkpoint primarily affects the model's coverage rather than the fidelity of individual samples”; and “we describe a simple recipe to choose an appropriate GAN checkpoint that is the most suitable for finetuning to a particular target task.” Argues you should usually use an ImageNet-pretrained GAN.
Cyclical Focal Loss
Another cyclic-something paper by the superconvergence guy. Proposes to generalize focal loss, which upweights loss from low-confidence predictions, to instead focus on high-confidence predictions at start and end of training, and low-confidence ones in the middle. Sorta maybe better results on some small experiments, but simple enough to code that it might be worth trying. Also uses the OLTR variants of ImageNet and Places, which could be interesting benchmarks for accuracy on heavily imbalanced datasets.
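Since it's simple enough to be worth trying, here's a hedged sketch of the idea; the blending schedule and the high-confidence term are my guesses at the shape, not the paper's exact formula:

```python
# Hedged sketch of the *idea*, not the paper's exact formula. Standard focal loss
# downweights confident predictions via (1 - p)^gamma; here that's blended with a
# confidence-upweighting term, and the blend cycles so high-confidence examples
# dominate early and late in training, low-confidence ones in the middle.
import torch
import torch.nn.functional as F

def cyclical_focal_loss(logits, targets, progress, gamma=2.0):
    # progress in [0, 1]: fraction of training completed
    ce = F.cross_entropy(logits, targets, reduction="none")
    p = torch.softmax(logits, dim=-1).gather(1, targets[:, None]).squeeze(1)
    low_conf_term = (1 - p) ** gamma * ce    # classic focal loss
    high_conf_term = (1 + p) ** gamma * ce   # my stand-in for "focus on confident examples"
    xi = abs(2 * progress - 1)               # 1 at start/end of training, 0 in the middle
    return (xi * high_conf_term + (1 - xi) * low_conf_term).mean()
```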
Survey on Large Scale Neural Network Training
Big detailed survey paper of distributed training schemes. Lots of tables distilling differences between approaches. A good reference for ramping up on this area.
Enabling On-Device Smartphone GPU based Training: Lessons Learned
Apparently Snapdragon GPUs have garbage memory bandwidth, to the point that training on the mobile CPU is faster. Data movement on the Snapdragon GPU takes up to 91% of the time. They wrote a bunch of custom kernels that got them to ~50 GFLOP/s on the GPU vs ~25 GFLOP/s on the CPU, which might be useful ballpark numbers to keep in mind when reasoning about mobile inference. Kind of skeptical of this though based on conversations with people who’ve done a lot of mobile inference optimization.
Deconstructing Distributions: A Pointwise Framework of Learning
They study the predictions on individual points as a function of various aspects of the predictor, most notably its overall accuracy. Besides the unsurprising finding that samples vary in difficulty, they find points whose correctness is consistently negatively correlated with overall accuracy. I.e., less accurate or partially-trained models are more likely to get them right. This flies in the face of most statistical learning theory under standard assumptions. They also put forth the idea of models learning specific “skills”, like examining global shape rather than local texture; e.g., recognizing a coffee cup (global shape) that has a picture of a dalmatian on it (striking but misleading texture). Main takeaway is that storing metadata about specific samples over time can yield useful insights, and there might be some set of domain-specific “skills” one could identify and use to help understand models better.
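A sketch of the kind of bookkeeping that takeaway suggests (my illustration, not their analysis code):

```python
# Log per-example correctness across many models/checkpoints, then surface examples
# whose correctness is *negatively* correlated with overall accuracy.
# (My illustration of the bookkeeping idea, not the paper's code.)
import numpy as np

def negatively_correlated_examples(correct, k=50):
    # correct: (num_models, num_examples) boolean matrix of per-example correctness
    overall_acc = correct.mean(axis=1)
    per_example = correct.astype(float)
    corrs = np.array([np.corrcoef(per_example[:, j], overall_acc)[0, 1]
                      for j in range(per_example.shape[1])])
    corrs = np.nan_to_num(corrs)     # examples every model gets right/wrong have no variance
    return np.argsort(corrs)[:k]     # indices of the most negatively correlated examples
```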
Loss as the Inconsistency of a Probabilistic Dependency Graph: Choose Your Model, Not Your Loss Function
Interesting but esoteric stuff I’d like to spend 20 hours ramping up on, but probably never will. Mostly listing this because this dude has the best LaTeX skills I’ve ever seen and I want to copy-paste his macros into all my future papers.
A New Generation of Perspective API: Efficient Multilingual Character-level Transformers
Google deployed a multilingual character-level transformer for toxic comment classification. Written by the Charformer people; Charformer was a paper that added a learnable tokenizer (the learnable tokenizer was complicated and didn’t seem to have a unifying idea, so I don’t fully understand it). Operating on UTF-8 bytes makes a ton of intuitive sense for sharing a model across languages, handling intra-message language switching, and dealing with emoji well. Mostly suggests that we might be able to get rid of tokenizers and simplify language model APIs.
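Tiny illustration of why bytes sidestep tokenizer headaches:

```python
# Every string, in any language and with any emoji, maps to integers in [0, 256)
# with no vocabulary at all.
text = "¿Qué tal? 😊 你好"
byte_ids = list(text.encode("utf-8"))
print(len(byte_ids), byte_ids[:8])  # more ids than characters, but nothing is ever OOV
```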
First is Better Than Last for Training Data Influence
Presents evidence that, when assessing the influence of specific training samples on a model’s predictions, you should focus on gradients of the word embedding layer rather than gradients of the output layer.
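A hedged sketch of the kind of comparison in play, gradient-similarity influence computed against a single layer's parameters (my illustration; `loss_fn` and the examples are stand-ins):

```python
# Gradient-similarity influence score restricted to one layer's parameters, where
# `layer_param` could be the word-embedding matrix (what they advocate) or the output
# projection. My illustration of the idea, not the paper's exact method.
import torch

def influence_score(model, loss_fn, train_example, test_example, layer_param):
    # loss_fn(model, example) -> scalar loss; layer_param: a parameter tensor of `model`
    g_train = torch.autograd.grad(loss_fn(model, train_example), layer_param)[0]
    g_test = torch.autograd.grad(loss_fn(model, test_example), layer_param)[0]
    return (g_train * g_test).sum().item()  # large positive = helpful training example
```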
⭐ Learning to Merge Tokens in Vision Transformers
Halfway through the (encoder) network, they reduce the number of patches to 8 using 8 learned query vectors, similar to the Perceiver. It’s always 8, even if eval is with higher-res images. 2x FLOP reduction, and 1.6x speedup, as evaluated on JFT300M and when fine-tuned for 10-shot ImageNet-1k.
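A sketch of what I understand the merging module to look like: cross-attention from 8 learned queries onto the patch tokens, Perceiver-style. Head count, init, and normalization details here are my guesses, not the paper's.

```python
# Cross-attention from a small fixed set of learned queries onto the patch tokens,
# which shrinks the token count to num_merged. Details (heads, init, norm) are guesses.
import torch
import torch.nn as nn

class TokenMerger(nn.Module):
    def __init__(self, d_model, num_merged=8, num_heads=8):
        super().__init__()
        self.queries = nn.Parameter(torch.randn(num_merged, d_model) * 0.02)
        self.attn = nn.MultiheadAttention(d_model, num_heads, batch_first=True)

    def forward(self, patch_tokens):
        # patch_tokens: (batch, num_patches, d_model) -> (batch, num_merged, d_model)
        q = self.queries.unsqueeze(0).expand(patch_tokens.size(0), -1, -1)
        merged, _ = self.attn(q, patch_tokens, patch_tokens)
        return merged
```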
⭐ Auto-scaling Vision Transformers without Training
They come up with a ViT design and scaling strategy based on features extracted from 87 offline ImageNet training runs, rather than based on interactively training each architecture their policy wants to try. The main feature they use is derived from feeding in a particular input that traces out a 2d unit circle in high-dimensional space, and looking at how much the length changes at the output (I have no idea how they chose this input generation function, or arrived at the three features they chose to evaluate). They also use the NTK condition number as a feature (ratio of largest and smallest eigenvals of NTK matrix). So basically they have this accuracy proxy they can compute for a given model that’s way cheaper than running the training. The auto-scaling uses the best small architecture they could find, and then increases the width and depth based on the same proxy. They run their full NAS search in 7 V100-hours. They also propose to use progressive resizing during training, but it basically doesn’t work (-1% ImageNet acc at 2x time reduction) despite them seemingly tuning the schedule.
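My attempt to reconstruct the length-distortion feature from that description (definitely not their code, and the exact construction may differ):

```python
# Trace a unit circle in a random 2-D subspace of input space, push it through the
# (untrained) network, and compare output arc length to input arc length.
# (My reconstruction of the feature from the description; the paper's may differ.)
import math
import torch

def length_distortion(model, input_shape, num_points=64):
    d = int(torch.tensor(input_shape).prod())
    basis = torch.linalg.qr(torch.randn(d, 2)).Q              # random 2-D orthonormal subspace
    theta = torch.linspace(0, 2 * math.pi, num_points)
    circle = torch.stack([theta.cos(), theta.sin()], dim=1)   # (num_points, 2) unit circle
    inputs = (circle @ basis.T).reshape(num_points, *input_shape)
    with torch.no_grad():
        outputs = model(inputs).flatten(1)
    input_len = inputs.flatten(1).diff(dim=0).norm(dim=1).sum()
    output_len = outputs.diff(dim=0).norm(dim=1).sum()
    return (output_len / input_len).item()
```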
⭐ Transformer Quality in Linear Time
Instead of alternating self-attention and FFN blocks, just do a Gated Attention Unit over and over again: elementwise-multiply a projection of the input with the output of relu² attention. The attention part derives Q and K from a single shared low-dimensional projection via cheap per-dim scale-and-offset transforms, rather than full projection matrices.
The part that isn’t captured in the above is that they also use a linear attention variant that lets them do autoregressive training much more quickly. Attention is quadratic within non-overlapping chunks of a fixed size and linear across chunks, so you can update the attention state in constant time per chunk as you move down the sequence, but quality doesn’t suck the way it usually does with linear attention. Transformer+ = Vanilla transformer + RoPE. Transformer++ = Vanilla transformer + RoPE + GLU. “all models are implemented in the same codebase to ensure identical tokenizer and hyper-parameters for training and evaluation.”
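A sketch of my reading of the GAU itself (the chunked/linear attention is omitted; expansion sizes, activations, and normalization are guesses):

```python
# Gated Attention Unit along the lines described above: a gating projection is
# elementwise-multiplied with the output of relu^2 attention, where Q and K come from
# cheap per-dim scale/offset transforms of one shared low-dimensional projection.
# Expansion sizes, activations, and normalization here are my assumptions.
import torch
import torch.nn as nn
import torch.nn.functional as F

class GatedAttentionUnit(nn.Module):
    def __init__(self, d_model, d_expanded=1536, d_qk=128):
        super().__init__()
        self.to_u = nn.Linear(d_model, d_expanded)   # gate
        self.to_v = nn.Linear(d_model, d_expanded)   # values
        self.to_z = nn.Linear(d_model, d_qk)         # shared low-dim space for Q and K
        self.q_scale, self.q_bias = nn.Parameter(torch.ones(d_qk)), nn.Parameter(torch.zeros(d_qk))
        self.k_scale, self.k_bias = nn.Parameter(torch.ones(d_qk)), nn.Parameter(torch.zeros(d_qk))
        self.out = nn.Linear(d_expanded, d_model)

    def forward(self, x):
        # x: (batch, seq_len, d_model)
        u, v = F.silu(self.to_u(x)), F.silu(self.to_v(x))
        z = self.to_z(x)
        q = z * self.q_scale + self.q_bias           # cheap per-dim transforms of shared z
        k = z * self.k_scale + self.k_bias
        attn = F.relu(q @ k.transpose(-2, -1) / x.size(1)) ** 2
        return self.out(u * (attn @ v))              # gate * attention output
```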