
FlashAttention (FA1–FA4) in PyTorch - educational implementations focused on algorithmic differences [P]

r/MachineLearning

I recently updated my FlashAttention-PyTorch repo so it now includes educational implementations of FA1, FA2, FA3, and FA4 in plain PyTorch. The main goal is to make the progression across versions easier to understand from code. This is not meant to be an optimized kernel repo, and it is not a hardware-faithful recreation of the official implementations. The point is to expose the algorithmic ideas and design changes without immediately going deep into CUDA/Hopper/Blackwell-specific details.
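The repo's own code is not reproduced in this post. The block below is a minimal sketch of the core algorithmic idea FlashAttention builds on: tiled attention with an online softmax. The function names `naive_attention` and `tiled_attention` and the `block_size` parameter are hypothetical and chosen for illustration. Instead of materializing the full score matrix, the sketch streams over key/value blocks and keeps a running max, a running softmax denominator, and an unnormalized output per query row. Following FA2, it divides by the denominator only once, at the end. It is single-head with no masking and no backward pass, and it makes no attempt to model GPU memory behavior.

```python
import torch


def naive_attention(q, k, v):
    # Reference implementation (hypothetical helper): builds the full
    # (n_q x n_k) score matrix and applies softmax over it in one shot.
    scale = q.shape[-1] ** -0.5
    scores = (q @ k.T) * scale
    return torch.softmax(scores, dim=-1) @ v


def tiled_attention(q, k, v, block_size=16):
    # Hypothetical sketch of the tiled online-softmax forward pass.
    # q: (n_q, d), k: (n_k, d), v: (n_k, d_v). Single head, no mask.
    scale = q.shape[-1] ** -0.5
    n_q = q.shape[0]
    m = torch.full((n_q, 1), float("-inf"), dtype=q.dtype)  # running row max
    l = torch.zeros((n_q, 1), dtype=q.dtype)                 # running softmax denominator
    acc = torch.zeros((n_q, v.shape[-1]), dtype=q.dtype)     # unnormalized output

    for start in range(0, k.shape[0], block_size):
        k_blk = k[start:start + block_size]
        v_blk = v[start:start + block_size]
        s = (q @ k_blk.T) * scale  # scores for this key block only

        # Update the running max, then rescale old statistics to the new max.
        m_new = torch.maximum(m, s.max(dim=-1, keepdim=True).values)
        correction = torch.exp(m - m_new)  # equals 0 on the first block, since m = -inf
        p = torch.exp(s - m_new)           # unnormalized block probabilities

        l = l * correction + p.sum(dim=-1, keepdim=True)
        acc = acc * correction + p @ v_blk
        m = m_new

    # FA2-style: normalize once at the end instead of after every block.
    return acc / l
```

Running both functions on random inputs (for example, 8 queries and 50 keys with `block_size=16`, so the last block is partial) should give outputs that match up to floating-point tolerance. The rescaling by `exp(m - m_new)` is what lets each block be processed independently while still producing an exact softmax.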