Enhancing Multimodal Training and Memory Efficiency with DeepSpeed
https://pytorch.org/blog/enhancing-multimodal-training-and-memory-efficiency-with-deepspeed/
Wed, 25 Feb 2026

Overview

This blog walks through two crucial DeepSpeed updates: (1) a PyTorch-identical backward API that enables efficient training of multimodal, multi-component models (including non-scalar backward calls), and (2) low-precision model training that significantly reduces peak memory.

For multimodal workloads, like combining a vision encoder with an LLM, training loops can become complex and multi-component. The first update introduces a PyTorch-identical backward API that makes writing such loops straightforward, enabling sophisticated parallelism schemes with simple, clean code, while DeepSpeed transparently manages various performance optimizations. As one example, the flexibility of the API enabled disaggregated hybrid parallelism, achieving a 30% speedup for multimodal AI model training while making model development with DeepSpeed feel closer to “vanilla PyTorch”.

Meanwhile, for LLM fine-tuning, a new option to keep all model states (parameters, gradients, and optimizer states) in lower precision, such as BF16 or FP16, drastically reduces the memory footprint, allowing researchers to train larger models on more constrained hardware. Low-precision training is highly beneficial across a wide range of applications, including supervised fine-tuning (SFT), reinforcement learning (RL), and multimodal training. Our experiment showed a 40% reduction in peak memory while maintaining numerical stability (benchmarking script). Numerical stability is maintained through integration with torch.autocast, which helps preserve model quality.

The remainder of this blog will elaborate on how these updates directly facilitate the development of cutting-edge training workloads.

1. PyTorch-identical backward API

DeepSpeed now supports PyTorch’s native backward() syntax while preserving all its optimizations. Traditionally, DeepSpeed’s training loop relied on the engine’s backward API:

loss = model_engine(batch)
model_engine.backward(loss)
model_engine.step()

The engine’s backward API was sufficient for traditional pretraining and fine-tuning pipelines. However, recent complex training pipelines require more flexibility. There were two major limitations:

  1. It only accepted a scalar loss.
  2. You had to call model_engine.backward(loss), rather than using the usual PyTorch loss.backward() style.

Due to these constraints, users could not simply implement patterns that vanilla PyTorch allows. Here are some examples:

# 1. Combine multiple models and losses
output1 = model1(batch1)
output2 = model2(batch2)
loss = criterion(output1, output2)
loss.backward()

# 2. Define a loss function separately from the main model
output = model(batch)
loss = loss_fn(output)
loss.backward()

# 3. Call backward through non-scalar tensors with custom gradients
output = model(batch)
output.backward(grad)

The DeepSpeed engine could handle these use cases through internal APIs; however, doing so required significant code changes and could easily introduce bugs. With the addition of the PyTorch-identical backward API, we can now use the same code as native PyTorch while keeping DeepSpeed's powerful optimizations, including ZeRO and offloading.
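In plain PyTorch, calling `output.backward(grad)` on a non-scalar `output` computes a vector-Jacobian product: the resulting gradients equal those of the scalar `(output * grad).sum()`. The dependency-free sketch below illustrates that identity for a linear map `y = W x` with hand-written gradients. It is not PyTorch or DeepSpeed code, and every name in it is illustrative.

```python
# Illustrative only: a hand-derived vector-Jacobian product (VJP) for y = W @ x,
# checked against finite differences of the scalar surrogate L = sum(grad_y * y).

def matvec(W, x):
    return [sum(w_ij * x_j for w_ij, x_j in zip(row, x)) for row in W]

def vjp_linear(W, x, grad_y):
    """Gradients of L = sum(grad_y * (W @ x)) with respect to W and x."""
    grad_W = [[g_i * x_j for x_j in x] for g_i in grad_y]
    grad_x = [sum(W[i][j] * grad_y[i] for i in range(len(W))) for j in range(len(x))]
    return grad_W, grad_x

def surrogate_loss(W, x, grad_y):
    # The scalar whose ordinary gradient equals the non-scalar backward result
    return sum(g * y for g, y in zip(grad_y, matvec(W, x)))

def finite_diff_grad_x(W, x, grad_y, eps=1e-6):
    grads = []
    for j in range(len(x)):
        x_plus, x_minus = list(x), list(x)
        x_plus[j] += eps
        x_minus[j] -= eps
        diff = surrogate_loss(W, x_plus, grad_y) - surrogate_loss(W, x_minus, grad_y)
        grads.append(diff / (2 * eps))
    return grads

W = [[1.0, 2.0, 0.5], [-1.0, 0.0, 3.0]]  # 2x3 weight matrix
x = [0.3, -0.7, 1.2]                     # input vector
grad_y = [0.1, -2.0]                     # upstream gradient, as in output.backward(grad)

grad_W, grad_x = vjp_linear(W, x, grad_y)
print(grad_x)                            # analytic VJP
print(finite_diff_grad_x(W, x, grad_y))  # numerical check
```

The `grad` argument plays the role of the upstream gradient; without it there is no single scalar to differentiate, which is why a non-scalar backward needs it.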

One example use case for the PyTorch-identical backward API is disaggregated hybrid parallelism for multimodal models using Ray. In this training pipeline, two Ray actor groups handle the vision encoder and the LLM separately. On a backward pass, the LLM passes a gradient to the vision encoder, and the vision encoder calls the backward function with that gradient. However, because the gradient is a non-scalar tensor, this use case wasn't officially supported by DeepSpeed APIs. Disaggregated hybrid parallelism shows that the flexible backward API, combined with DeepSpeed's optimizations and DeepSpeed-Ulysses (highly efficient sequence parallelism), achieves a 30% training speedup.

Below is the pseudo-code for the two models running on different actors. Since they run in different processes, we pass gradients via Ray actor communication. As seen here, the gradient of the vision embedding is a non-scalar tensor. Although this code is identical to the PyTorch API, it will activate various DeepSpeed optimizations based on your configuration.

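# Runs on LLM actors
def text_backward_step(self):
    # ...
    self.loss.backward()
    # Gradient w.r.t. the vision embeddings: a non-scalar tensor sent back to the vision actors
    return self.vision_embeddings.grad.detach().clone()

# Runs on Vision actors
def vision_backward_step(self, vision_embedding_grad):
    # Backpropagate the gradient received from the LLM actors through the vision encoder
    self.vision_output.backward(gradient=vision_embedding_grad)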

Check out the repository for the complete training pipeline.
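To see why handing the embedding gradient across actors is enough, here is a toy, single-process sketch with hand-written gradients. A hypothetical "vision" stage computes `e = a * x` elementwise, a hypothetical "LLM" stage computes `loss = sum(w * e) ** 2`, and the gradient the LLM stage returns for `e` is all the vision stage needs to compute its own parameter gradient. No Ray, PyTorch, or DeepSpeed calls are involved; the stage names only mirror the pseudo-code above.

```python
# Toy two-stage pipeline with manual gradients. In the real setup each stage
# lives in a separate Ray actor group; here they are plain functions.

def vision_forward(a, x):
    return [a_i * x_i for a_i, x_i in zip(a, x)]  # e = a * x (elementwise)

def llm_forward_backward(w, e):
    s = sum(w_i * e_i for w_i, e_i in zip(w, e))
    loss = s ** 2
    grad_e = [2 * s * w_i for w_i in w]  # dloss/de: a non-scalar "tensor"
    return loss, grad_e

def vision_backward(x, grad_e):
    # Chain rule: dloss/da_i = dloss/de_i * de_i/da_i = grad_e[i] * x[i]
    return [g * x_i for g, x_i in zip(grad_e, x)]

def end_to_end_loss(a, x, w):
    loss, _ = llm_forward_backward(w, vision_forward(a, x))
    return loss

a = [0.5, -1.0, 2.0]   # vision-stage parameters
x = [1.0, 2.0, -0.5]   # vision-stage input
w = [0.3, 0.1, -0.4]   # LLM-stage parameters

e = vision_forward(a, x)                   # "vision" side
loss, grad_e = llm_forward_backward(w, e)  # "LLM" side, sends grad_e back
grad_a = vision_backward(x, grad_e)        # "vision" side

# Check against central finite differences of the end-to-end loss
eps = 1e-6
fd = []
for i in range(len(a)):
    a_plus, a_minus = list(a), list(a)
    a_plus[i] += eps
    a_minus[i] -= eps
    fd.append((end_to_end_loss(a_plus, x, w) - end_to_end_loss(a_minus, x, w)) / (2 * eps))
print(grad_a, fd)
```

The split gradient matches the end-to-end one, which is the property the actor-to-actor handoff relies on.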

2. Memory-efficient low-precision model states

You can now keep all model states (parameters, gradients, and optimizer states) in BF16 or FP16, significantly reducing memory consumption.

Traditionally, DeepSpeed’s mixed precision keeps FP32 master parameters, gradients, and optimizer states, which is technically safer but memory-intensive. While DeepSpeed has supported torch.autocast via configuration (see the API documentation), the lack of an option to bypass creating FP32 states limited the trainability of large models on constrained hardware. In practice, many training workloads converge stably without FP32 states.

With the low-precision model states option, you can easily skip creating FP32 states and combine the low-precision option with torch.autocast support (see the document and example for configuration details). This combination drastically improves memory efficiency without sacrificing convergence.

{
...
  "zero_optimization": {
    "stage": 3,
    ...
  },
  "bf16": {
    "enabled": true,
    "bf16_master_weights_and_grads": true,
    "bf16_optimizer_states": true
  },
  "torch_autocast": {
    "enabled": true,
    "dtype": "bfloat16"
  }
}
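For intuition about where the savings come from, here is a back-of-the-envelope estimate of model-state memory per GPU. The byte breakdowns are assumptions, not DeepSpeed's exact accounting: an Adam-style optimizer with two moment buffers, a BF16 working copy plus FP32 master parameters, gradients, and moments for the baseline, and everything in BF16 with no separate FP32 copies for the low-precision option. The figures cover model states only, so they are not directly comparable to measured allocated or peak memory, which also include activations and temporary buffers.

```python
# Back-of-the-envelope model-state memory per GPU under ZeRO-3.
# ASSUMPTIONS (illustrative, not DeepSpeed's exact accounting): an Adam-style
# optimizer with two moment buffers, and all model states evenly partitioned
# across GPUs. Activations, communication buffers, and fragmentation are ignored.

BASELINE = {                   # bytes per parameter
    "bf16_params": 2,          # low-precision working copy
    "fp32_master_params": 4,
    "fp32_grads": 4,
    "fp32_adam_m": 4,
    "fp32_adam_v": 4,
}
LOW_PRECISION = {              # bytes per parameter; no separate FP32 copies
    "bf16_params": 2,
    "bf16_grads": 2,
    "bf16_adam_m": 2,
    "bf16_adam_v": 2,
}

def model_state_gb_per_gpu(n_params, bytes_per_param, n_gpus):
    """Total model-state bytes split evenly across GPUs, in decimal GB."""
    return n_params * sum(bytes_per_param.values()) / n_gpus / 1e9

n_params, n_gpus = 7e9, 4  # hypothetical: a 7B model on 4 GPUs
base = model_state_gb_per_gpu(n_params, BASELINE, n_gpus)
low = model_state_gb_per_gpu(n_params, LOW_PRECISION, n_gpus)
print(f"baseline: {base:.1f} GB, low-precision: {low:.1f} GB, saving: {1 - low / base:.0%}")
```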

Our example script demonstrates the significant memory savings:

| Configuration | Allocated Memory | Peak Memory | Avg Step Time |
| --- | --- | --- | --- |
| Baseline (FP32 master) | 25.74 GB | 31.38 GB | 0.6016 s |
| BF16 low-precision (master + optimizer states) | 16.17 GB | 18.93 GB | 0.6427 s |

The experiment (7B model, ZeRO-3, 4 GPUs) showed a 40% reduction in peak memory, at the cost of a roughly 7% longer average step time. To verify that BF16 low-precision training maintains numerical stability, we trained for 1000 steps on the Wikitext-103 dataset:

Loss curve comparison

| Configuration | Final Loss | Mean Loss |
| --- | --- | --- |
| Baseline (FP32 master) | 3.09 | 2.78 |
| BF16 low-precision | 3.12 | 2.90 |
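The relative changes behind these comparisons can be recomputed directly from the two tables above. Here is a minimal sketch; the dictionary keys are just labels for the reported columns.

```python
# Relative changes computed from the two tables above (values copied as reported).
baseline = {"allocated_gb": 25.74, "peak_gb": 31.38, "step_s": 0.6016,
            "final_loss": 3.09, "mean_loss": 2.78}
bf16_low = {"allocated_gb": 16.17, "peak_gb": 18.93, "step_s": 0.6427,
            "final_loss": 3.12, "mean_loss": 2.90}

def rel_change(before, after):
    """Signed relative change from `before` to `after` (e.g. -0.4 means 40% lower)."""
    return (after - before) / before

for metric in baseline:
    print(f"{metric:>12}: {rel_change(baseline[metric], bf16_low[metric]):+.1%}")
```

This works out to about 40% lower peak memory and 37% lower allocated memory, against about 7% longer steps and roughly 1% higher final loss and 4% higher mean loss.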

Related Tests

We continuously test these new APIs in our CI, and you can see various use-case patterns in the tests.

Closing Thoughts

This DeepSpeed update delivers key advancements:

  • Enabling Complex Multimodal Workloads: The new PyTorch-identical backward API enables sophisticated multi-component training loops, such as those required for multimodal models, with simple, clean code. As one example, the PyTorch-identical backward API has enabled a 30% speedup for disaggregated hybrid parallelism.
  • Scaling to Larger Models: Low-precision model states combined with torch.autocast reduce peak memory by up to 40% without sacrificing convergence, allowing you to train larger models with the same hardware.

We are excited to see how you use the new APIs and features described in this blog post in your own training setups, and we welcome feedback and issues on GitHub as you try them out.

Accelerating Autotuning in Helion with Bayesian Optimization
https://pytorch.org/blog/accelerating-autotuning-in-helion/
Tue, 24 Feb 2026

Introduction

As introduced in a previous blog post, Helion is a high-level DSL that empowers developers to write high-performance ML kernels using a familiar PyTorch-like syntax, delegating the complex task of optimization to its autotuning engine. This autotuner explores a vast, high-dimensional space of implementation choices (block sizes, loop orders, memory access patterns) to discover configurations that maximize performance on the target hardware. As a result, Helion can achieve significant speedups over torch.compile and even over highly optimized, hand-written kernels in Triton or CuTe DSL.

However, the performance gains from autotuning come at a cost: long wall-clock times. A typical autotuning session can take 10+ minutes, evaluating thousands of candidate configurations, and can even take on the order of hours for complex kernels. Since Helion's launch, long autotuning times have consistently surfaced as a user complaint and one of the biggest pain points in the kernel development cycle. While Helion gives developers options to shorten the autotuning process, e.g. by reducing the number of search steps, doing so typically costs kernel performance, forcing an undesirable trade-off.

In this blog post, we discuss our ongoing efforts to improve the autotuning experience. In particular, we present LFBO Pattern Search, a new search algorithm we developed to address these issues, which uses techniques from machine learning (ML) to make the autotuning engine more efficient. The algorithm trains an ML model to intelligently filter candidate configurations, substantially reducing the number of candidates evaluated. Importantly, the model only uses data collected during the search process and doesn't require the user to provide any additional data.

Using ML, we can reduce autotuning time substantially without sacrificing performance:

  • On our set of benchmark NVIDIA B200 kernels, we reduce autotuning time by 36.5% while improving kernel latency by 2.6% on average. 
  • On AMD MI350 kernels, we reduce autotuning time by 25.9% while improving kernel latency by 1.7%.

For some kernels the improvements are especially significant: we see up to a 50% reduction in wall-clock time for B200 layer-norm kernels, and even a >15% improvement in kernel latency for B200 Helion FlashAttention kernels. Because of these gains, LFBO Pattern Search is the default search algorithm at the time of writing.

The Challenges of Kernel Autotuning

The autotuning engine searches through kernel configurations, benchmarking their latency and using the outcomes to determine the next set of configs to benchmark. While compiling and measuring the latency of a single configuration takes on the order of seconds, the autotuning engine typically searches through thousands of configurations to achieve the best possible performance. Finding the optimal kernel configuration is a challenging optimization problem due to several factors inherent to the design space:

  • High-Dimensional, Combinatorial Space: The space of all possible combinations of block sizes, unroll factors, etc. is high-dimensional and vast. Even a simple kernel like LayerNorm has more than 8 quadrillion (8 × 10^15) possible configurations. However, while the search space is large, only a small fraction of configs have good performance.
  • Long Compile Times: Certain kernel configurations can take a significant amount of time to compile, unnecessarily extending the autotuning process’s wall-clock time.
  • Config Errors and Timeouts: The search space can also include configs that have compilation errors, produce inaccurate results, or take too long to compile.

The previous default search strategy (Pattern Search) starts from multiple promising configurations (‘search copies’) and explores neighboring configs by exhaustively evaluating all single-parameter perturbations. While thorough, this approach is inefficient: the vast majority of neighbors offer no performance improvement, yet each is compiled and benchmarked. Furthermore, restricting moves to single-parameter changes limits the algorithm’s ability to traverse the high-dimensional search space quickly.
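Here is a minimal sketch of what "all single-parameter perturbations" means, on a hypothetical five-parameter space. The parameter names and value lists are illustrative, not Helion's actual config space, and Helion may restrict moves to nearby values rather than every alternative.

```python
# Hypothetical search space: each parameter has a small list of allowed values.
SPACE = {
    "block_size_m": [16, 32, 64, 128],
    "block_size_n": [16, 32, 64, 128],
    "num_warps": [2, 4, 8],
    "num_stages": [1, 2, 3, 4],
    "loop_order": ["mn", "nm"],
}

def single_param_neighbors(config, space):
    """All configs that differ from `config` in exactly one parameter."""
    neighbors = []
    for name, values in space.items():
        for value in values:
            if value != config[name]:
                neighbor = dict(config)
                neighbor[name] = value
                neighbors.append(neighbor)
    return neighbors

start = {"block_size_m": 64, "block_size_n": 64, "num_warps": 4,
         "num_stages": 2, "loop_order": "mn"}
neighbors = single_param_neighbors(start, SPACE)
print(len(neighbors))  # each of these would be compiled and benchmarked
```

Even in this tiny space, every step costs one compile-and-benchmark per neighbor, and no single move can change two parameters at once.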

Likelihood-Free Bayesian Optimization Pattern Search

To address these inefficiencies, we take inspiration from Bayesian Optimization, a sub-domain of machine learning that uses a probabilistic surrogate model (e.g. a Gaussian Process) to intelligently select which points to evaluate next (available in libraries such as BoTorch and Ax). To minimize additional wall-clock time, we adapt Likelihood-Free Bayesian Optimization (LFBO), which uses a lighter-weight classification model as the surrogate. We combine the local search heuristic of Pattern Search with the LFBO classifier so that only the most promising candidates are benchmarked, instead of searching exhaustively.

The LFBOPatternSearch algorithm is as follows:

  1. Similar to PatternSearch, we first benchmark a set of randomly generated configs, and identify a small set of the most promising configurations (‘search copies’).
  2. We generate candidates from the search copies by making random perturbations across multiple parameters, exploring more widely than PatternSearch.
  3. We train a classification model (RandomForest) on latency data collected so far. Instead of predicting latency directly, we predict a binary label indicating whether the config is in the top 10% in terms of latency.
  4. We rank the candidates based on ML model predictions. Unlike typical LFBO, we also add a penalty for similarity to previously ranked candidates to encourage exploration.
  5. We select the top 10% of them to compile and benchmark. We update the search copies based on the best performing configs and add the latencies to the dataset.
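The five steps above can be sketched as a toy, dependency-free loop. Everything in it is hypothetical: the search space, a synthetic `fake_latency` standing in for compile-and-benchmark, and a k-nearest-neighbor classifier swapped in for the Random Forest so the sketch needs only the standard library. The diversity penalty from step 4 is omitted for brevity, and Helion's actual implementation and parameters differ.

```python
import random

# Hypothetical toy search space (not Helion's)
SPACE = {
    "block_m": [16, 32, 64, 128],
    "block_n": [16, 32, 64, 128],
    "num_warps": [2, 4, 8],
    "num_stages": [1, 2, 3, 4],
}

def fake_latency(cfg):
    # Made-up cost surface; minimum 1.0 at block_m=64, block_n=128, num_warps=4, num_stages=3
    return (1.0 + abs(cfg["block_m"] - 64) / 64 + abs(cfg["block_n"] - 128) / 128
            + abs(cfg["num_warps"] - 4) / 4 + abs(cfg["num_stages"] - 3) / 3)

def key(cfg):
    return tuple(cfg[name] for name in SPACE)

def perturb(cfg, rng, p=0.5):
    # Step 2: resample several parameters at once (always at least one)
    new = dict(cfg)
    names = [n for n in SPACE if rng.random() < p] or [rng.choice(list(SPACE))]
    for n in names:
        new[n] = rng.choice(SPACE[n])
    return new

def top_fraction_labels(latencies, frac=0.1):
    # Step 3: label 1 if a config is among the fastest `frac` seen so far
    k = max(1, int(len(latencies) * frac))
    cutoff = sorted(latencies)[k - 1]
    return [int(lat <= cutoff) for lat in latencies]

def knn_score(cand, configs, labels, k=5):
    # Stand-in classifier: share of positive labels among the k evaluated configs
    # closest to `cand` in Hamming distance (ties keep evaluation order)
    by_distance = sorted(zip(configs, labels),
                         key=lambda pair: sum(cand[n] != pair[0][n] for n in SPACE))
    nearest = by_distance[:k]
    return sum(label for _, label in nearest) / len(nearest)

def lfbo_pattern_search(n_init=20, n_rounds=10, n_copies=3, n_candidates=60, seed=0):
    rng = random.Random(seed)
    # Step 1: benchmark random configs
    configs = [{n: rng.choice(v) for n, v in SPACE.items()} for _ in range(n_init)]
    lats = [fake_latency(c) for c in configs]
    seen = {key(c) for c in configs}
    for _ in range(n_rounds):
        order = sorted(range(len(configs)), key=lambda i: lats[i])
        copies = [configs[i] for i in order[:n_copies]]  # current search copies
        cands, cand_keys = [], set()
        for _ in range(n_candidates):
            c = perturb(rng.choice(copies), rng)
            if key(c) not in seen and key(c) not in cand_keys:
                cands.append(c)
                cand_keys.add(key(c))
        if not cands:
            break
        labels = top_fraction_labels(lats)
        # Step 4: rank candidates by the classifier's score
        cands.sort(key=lambda c: knn_score(c, configs, labels), reverse=True)
        # Step 5: benchmark only the top 10% and add them to the dataset
        for c in cands[: max(1, len(cands) // 10)]:
            configs.append(c)
            lats.append(fake_latency(c))
            seen.add(key(c))
    best = min(range(len(configs)), key=lambda i: lats[i])
    return configs[best], lats[best], len(configs)

best_cfg, best_lat, n_evaluated = lfbo_pattern_search()
print(best_cfg, round(best_lat, 3), n_evaluated)
```

The key cost saving is in step 5: of the candidates generated each round, only the top-ranked tenth is actually benchmarked.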

We discuss some key design decisions, which are critical for achieving the improvements in wall-clock time and latency we observed:

Classification vs Regression: Regression-based methods, i.e. training the model to predict latency directly, are the de facto approach for cost modeling in systems and compiler research. However, we find that a classification-based approach works better for two reasons. First, it focuses model capacity on the most performant configs instead of trying to learn the latency of all configs, good or bad. Second, the classification loss lets the model learn to avoid configs that error out or hit compile timeouts, since these are assigned negative labels; a regression-based approach has no valid latency data to learn from for these points.
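The difference in what each approach can learn from shows up directly when building training data. Here is a minimal sketch, assuming a hypothetical result format in which failed, incorrect, or timed-out configs have a latency of `None`, and computing the top-10% cutoff over successful configs only.

```python
def make_training_data(results, top_frac=0.1):
    """Build classifier and regressor training sets from autotuning results.

    `results` is a list of (config, latency) pairs, where latency is None when the
    config failed to compile, produced wrong output, or timed out (a hypothetical
    format). The top-`top_frac` cutoff is computed over successful configs only.
    """
    valid = sorted(lat for _, lat in results if lat is not None)
    cutoff = valid[max(1, int(len(valid) * top_frac)) - 1] if valid else float("-inf")
    # Classification: every config gets a label; failures are simply negatives.
    clf_data = [(cfg, int(lat is not None and lat <= cutoff)) for cfg, lat in results]
    # Regression: failures have no latency to predict, so they must be dropped.
    reg_data = [(cfg, lat) for cfg, lat in results if lat is not None]
    return clf_data, reg_data

results = [("a", 1.2), ("b", None), ("c", 0.8), ("d", 3.0), ("e", None), ("f", 1.0),
           ("g", 2.2), ("h", 0.9), ("i", 1.7), ("j", 5.0), ("k", 2.5), ("l", 1.1)]
clf_data, reg_data = make_training_data(results)
print(sum(label for _, label in clf_data), len(clf_data), len(reg_data))
```

The classifier trains on all 12 results, including the two failures as negatives, while a regressor would see only the 10 successful ones.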

Encouraging Diversity: Configs are typically compiled in batches to take advantage of parallelized pre-compilation. The Random Forest classifier may repeatedly select similar, clustered configurations, which can waste the batch budget on redundant samples that provide little new information. To mitigate this, we compute a similarity score based on leaf-node co-occurrence in the Random Forest model and penalize similarity to previously ranked configs.
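Here is a minimal sketch of similarity-penalized ranking. It assumes each candidate's leaf index in every tree is already known; with scikit-learn such indices could come from `RandomForestClassifier.apply`, but here they are hard-coded. Penalizing the maximum similarity to already-ranked candidates, and the penalty weight of 0.5, are illustrative choices, not necessarily Helion's.

```python
def leaf_similarity(leaves_a, leaves_b):
    """Fraction of trees in which two configs fall into the same leaf."""
    return sum(a == b for a, b in zip(leaves_a, leaves_b)) / len(leaves_a)

def adjusted_score(i, scores, leaves, ranked, penalty):
    # Classifier score minus a penalty for the most similar already-ranked candidate
    if not ranked:
        return scores[i]
    max_sim = max(leaf_similarity(leaves[i], leaves[j]) for j in ranked)
    return scores[i] - penalty * max_sim

def diverse_rank(candidates, scores, leaves, penalty=0.5):
    """Greedily rank candidates, discounting those similar to earlier picks."""
    remaining = list(range(len(candidates)))
    ranked = []
    while remaining:
        best = max(remaining, key=lambda i: adjusted_score(i, scores, leaves, ranked, penalty))
        ranked.append(best)
        remaining.remove(best)
    return [candidates[i] for i in ranked]

# Hypothetical leaf indices (one per tree) for three candidates; A and B land
# in identical leaves, so the forest treats them as near-duplicates.
candidates = ["A", "B", "C"]
scores = [0.90, 0.85, 0.60]
leaves = [[3, 7, 1, 4], [3, 7, 1, 4], [5, 2, 1, 9]]
print(diverse_rank(candidates, scores, leaves))               # penalized ranking
print(diverse_rank(candidates, scores, leaves, penalty=0.0))  # plain score ranking
```

With the penalty, the distinct candidate C moves ahead of B, the near-duplicate of A, so a small batch covers more of the space.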

When we investigate the behavior of LFBO Pattern Search, we see that its improvements across kernels, hardware types, and shapes indeed come from its ability to find configurations with better runtime using fewer evaluations. Below is a plot of example autotuning traces for a B200 layer-norm kernel, showing the latency of the best configuration found by the autotuner over time. Not only does LFBO finish autotuning earlier (~5 min instead of ~9 min), it also finds better configurations faster, with much larger jumps in performance than Pattern Search.

LFBO accomplishes this by exploring more widely than Pattern Search. Below is a plot of configs sampled by LFBO Pattern Search and Pattern Search for the same B200 layer-norm kernel, projected with Principal Component Analysis (PCA) for visualization since configs are high-dimensional. Although LFBO Pattern Search evaluates fewer than half as many configs, its samples are more spread out than Pattern Search's, which are tightly clumped because Pattern Search makes only single-parameter perturbations. Guided by the classifier, LFBO Pattern Search makes larger but more targeted jumps.

Finally, we perform an ablation with other surrogate models, in particular regression-based approaches using a Random Forest, a Gradient-Boosting Tree, and a Multi-Layer Perceptron (MLP), on a dataset of autotuner logs collected from PatternSearch. We compute the metric most directly correlated with autotuner performance: the expected improvement in kernel latency when using the surrogate to filter the next batch of candidates. Below we plot the expected improvement (as a relative % improvement in latency) against the percentage of candidates the surrogate is allowed to select. The LFBO-based methods deliver the largest expected improvement, with meaningful further gains from diverse selection. Notably, when only 10% of candidates can be selected, the regression-based methods perform on par with or even worse than simple random selection, since regression accuracy is not always aligned with ranking performance.

Conclusion

In this blog post, we illustrate how machine learning (ML) can accelerate the autotuning engine and improve the kernel authoring experience in Helion. By using the latency data collected during the search process, we can focus the autotuner on more promising configurations, saving time and discovering faster kernel configs. We are actively interested in applying additional ML techniques to enhance the autotuner, including methods from reinforcement learning (RL) and large language models (LLMs), and welcome any contributions.

PyTorch Foundation Announces New Members as Agentic AI Demand Grows
https://pytorch.org/blog/pytorch-foundation-announces-new-members-as-agentic-ai-demand-grows/
Tue, 24 Feb 2026

Foundation welcomes Clockwork.io, Emmi AI, NIPA, Nota AI, yasp.ai, CommonAI CIC, Carnegie Mellon University, Monash University, and University of Leicester

Summary

  • The PyTorch Foundation announced nine new members, adding five Silver members and four Associate members to its vibrant community since December 2025.
  • Membership growth underscores rising demand for open, community-driven, production-ready AI tooling as agentic AI accelerates, strengthening the ecosystem around PyTorch Foundation projects like PyTorch, vLLM, DeepSpeed, and Ray.
  • New Silver members include Clockwork.io, Emmi AI, National IT Industry Promotion Agency (NIPA), Nota AI, and yasp.ai, and new Associate members include Carnegie Mellon University, CommonAI CIC, Monash University, and University of Leicester.
  • Developers, engineers, and industry leaders alike are encouraged to register for PyTorch Conference Europe in Paris on April 7-8, 2026 to learn more about the growing open source AI ecosystem.

NAPA, Calif. – Linux Foundation Member Summit – Feb. 24, 2026 – The PyTorch Foundation, a community-driven hub for open source AI under the Linux Foundation, today announced significant expansion of its membership, with nine new members joining since December 2025. New members include Carnegie Mellon University, Clockwork.io, CommonAI CIC, Emmi AI, Monash University, National IT Industry Promotion Agency (NIPA), Nota AI, University of Leicester, and yasp.ai.

New PyTorch Foundation membership signals sustained growth and progress in agentic AI innovation, with the Foundation leading the way on open, community-driven AI. Fueled by industry participation from leading universities, AI startups, global governments, and more, the PyTorch Foundation’s production-ready tools and libraries – including PyTorch, vLLM, DeepSpeed, and Ray – play integral roles in the AI stack.

“There are no agentic systems without the models that power them,” said Mark Collier, GM of AI at the Linux Foundation and Executive Director of the PyTorch Foundation. “From training frameworks like PyTorch and optimization systems like DeepSpeed that create capabilities such as advanced tool calling, to inference engines and orchestration layers like vLLM and Ray that operationalize them, the Foundation hosts critical layers of the open source AI stack. The growth of our membership reflects a shared recognition that these capabilities must be built collaboratively in a vendor-neutral environment.”

Clockwork.io, Emmi AI, NIPA, Nota AI, and yasp join the foundation as Silver members. CommonAI CIC, Carnegie Mellon University, Monash University, and University of Leicester join as Associate members.

This news follows Ray joining the PyTorch Foundation as a foundation-hosted project in October 2025. The open source distributed computing framework for AI workloads, developed by Anyscale, offers development teams a seamless way to execute data processing, forming an integrated open source distributed computing layer for agentic AI alongside vLLM and PyTorch as part of the foundation.

To learn more, join the global PyTorch community in Paris, France from April 7-8, 2026 for the inaugural PyTorch Conference Europe. Register here for early-bird pricing on the latest in open source AI and machine learning.

Supporting Quotes

“At Emmi AI, PyTorch is a key part of how we bring AI into real-world engineering workflows. Becoming a member of the PyTorch Foundation is a natural step for us as we contribute to an open ecosystem that accelerates research, deployment, and impact.”

– Miks Mikelsons, Chief Operating Officer & Co-Founder, Emmi AI

“Open source AI plays a critical role in bringing research innovations into real-world applications. By joining the PyTorch Foundation, we look forward to collaborating with the community and contributing our experience in AI model optimization.”

– Tae-Ho Kim, CTO, Nota AI 

“AI teams shouldn’t have to redesign their models every time the hardware changes. Our work focuses on separating model innovation from infrastructure constraints, so developers can run efficiently anywhere. Becoming part of the PyTorch Foundation aligns with our belief that open ecosystems are essential to reduce friction, avoid lock-in, and scale AI sustainably.”

– Reza Rahimi, CTO, yasp

###

About the PyTorch Foundation

The PyTorch Foundation is a community-driven hub supporting the open source PyTorch framework and a broader portfolio of innovative open source AI projects. Hosted by the Linux Foundation, the PyTorch Foundation provides a vendor-neutral, trusted home for collaboration across the AI lifecycle—from model training and inference, to domain-specific applications. Through open governance, strategic support, and a global contributor community, the PyTorch Foundation empowers developers, researchers, and enterprises to build and deploy AI at scale. Learn more at https://pytorch.org/foundation.

About the Linux Foundation

The Linux Foundation is the world’s leading home for collaboration on open source software, hardware, standards, and data. Linux Foundation projects, including Linux, Kubernetes, Model Context Protocol (MCP), OpenChain, OpenSearch, OpenSSF, OpenStack, PyTorch, Ray, RISC-V, SPDX and Zephyr, provide the foundation for global infrastructure. The Linux Foundation is focused on leveraging best practices and addressing the needs of contributors, users, and solution providers to create sustainable models for open collaboration. For more information, please visit us at linuxfoundation.org.

The Linux Foundation has registered trademarks and uses trademarks. For a list of trademarks of The Linux Foundation, please see its trademark usage page: www.linuxfoundation.org/trademark-usage. Linux is a registered trademark of Linus Torvalds.

Media Contact

Grace Lucier
The Linux Foundation
pr@linuxfoundation.org

PyTorchCon Europe Schedule is Live
https://pytorch.org/blog/pytorchcon-europe-schedule-is-live/
Mon, 23 Feb 2026

The schedule for PyTorch Conference Europe is officially live! Join us 7-8 April in Paris for 95+ sessions across 8 tracks over 2 days dedicated to the builders, researchers, and practitioners advancing the PyTorch ecosystem across Europe and beyond. From core framework innovation to applied AI breakthroughs, this is where the community connects around what is working right now and what is coming next.

Featured Session Highlights

View the full schedule.

Join Us

Early bird registration pricing ends this week. Buy your ticket by 27 February to save.

What's Included?

  • All conference content
  • Social events & networking activities
  • Community Expo
  • Lunches, morning & afternoon coffee & snacks
  • Event swag

Student or faculty? Learn more about our discounted academic rate.

Need help getting to the event? Scholarships are available through 11 March. Apply now!

Register now for PyTorchCon EU.

Become a Sponsor

Seize your opportunity to influence the future of Generative AI and Machine Learning by becoming a sponsor! PyTorch is at the forefront of innovation—empowering rapid experimentation, flexible model development, and efficient deployment into production environments with its powerful, versatile ecosystem of tools and thriving community of dedicated users.

As a sponsor, you’ll be in the heart of the vibrant, global AI/ML ecosystem, connecting directly with hundreds of expert attendees, researchers, engineers, and decision-makers who are actively shaping the conversations driving the next generation of advancements.

Pyrefly Now Type Checks PyTorch
https://pytorch.org/blog/pyrefly-now-type-checks-pytorch/
Thu, 12 Feb 2026

We're excited to share that PyTorch now leverages Pyrefly to power type checking across our core repository, along with a number of projects in the PyTorch ecosystem: Helion, TorchTitan and Ignite. For a project the size of PyTorch, leveraging typing and type checking has long been essential for ensuring consistency and preventing common bugs that often go unnoticed in dynamic code. Migrating to Pyrefly brings a much needed upgrade to these development workflows, with lightning-fast, standards-compliant type checking and a modern IDE experience. With Pyrefly, our maintainers and contributors can catch bugs earlier, benefit from consistent results between local and CI runs, and take advantage of advanced typing features. In this blog post, we'll share why we made this transition and highlight the improvements PyTorch has already experienced since adopting Pyrefly.

Why Switch to Pyrefly?

To support the future development of PyTorch, we wanted a type checker that is fast, easy to use, consistent across developer environments, and actively maintained. These factors ultimately influenced the decision to move forward with Pyrefly.

Balancing Speed with Accuracy

In a recent benchmarking round, type checking PyTorch took 50.6 seconds with MyPy, whereas Pyrefly (v44.1) took only 5.5 seconds, roughly a 9x speedup over PyTorch's existing tooling while still maintaining robust type safety. We wanted an alternative that not only delivered fast results, but would also help our contributors catch bugs early and identify gaps in our type coverage. Pyrefly appears to strike the right balance for us, being fast enough to keep up with our development speed without compromising on the quality of type safety.

That said, we see this as just the beginning; there is still room for Pyrefly to become even faster, and we expect to benefit from even greater speed gains as the tool continues to evolve. We’ll be closely following Pyrefly’s ongoing development and look forward to integrating future performance enhancements as they become available.

Simplified Configuration

Previously, our reliance on MyPy required contributors to juggle multiple configuration files to manage coverage and strictness levels across the codebase. This made it difficult to determine exactly which files were being checked and under what specific rules. Transitioning to Pyrefly has helped address these challenges. With direct support from the Pyrefly team, PyTorch has now transitioned to use a single unified Pyrefly configuration and required suppressions, making it much easier for our maintainers to understand which files are being typechecked and how. 

Consistency across Development Environments

Previously, developers often encountered discrepancies between their IDE, local CLI, and the CI environment because different type-checking engines were used at each stage. MyPy was used in PyTorch CI jobs, but in their IDEs developers often preferred other type checkers that behaved slightly differently. In other cases, developers ran MyPy locally with a different strictness mode than the one used in CI. These inconsistencies led to unpredictable feedback loops and a frustrating experience where code that passed local type checking would fail in CI. By adopting Pyrefly, which provides a high-quality IDE experience alongside robust CLI and CI functionality, PyTorch developers now get consistent results across all their development environments.

| Environment | Before | After |
| --- | --- | --- |
| CI | MyPy (full project run) | Pyrefly |
| CLI | MyPy (only on select files) | Pyrefly |
| IDE | Pyright or other | Pyrefly |

Active Maintenance and Rapid Development

Another major reason for migrating is that Pyrefly is actively maintained and evolving quickly, with significant room for continual performance improvements. We've appreciated the responsiveness to user feedback and the rapid development cycles, which include new minor releases every Monday. It's not uncommon for a bug to be reported and resolved in time for the very next release, ensuring that issues are addressed and new features are delivered promptly. An example of this is described in a recent Pyrefly blog post, where a performance bottleneck was identified and promptly resolved, resulting in an 18x speedup in IDE responsiveness across the PyTorch codebase.

Throughout this migration, and as we continue using Pyrefly, our priority is to avoid regressions in type safety or developer experience. Maintaining a regular line of communication with the Pyrefly team has been essential for quickly addressing edge cases and enabling a smooth transition for our contributors.

Additional Benefits for PyTorch Contributors

PyTorch contributors and maintainers have already experienced meaningful improvements since moving to Pyrefly. Beyond the initial motivations for the transition, other benefits include the following:

Improved code quality

The rollout of Pyrefly has already led to the discovery and resolution of numerous bugs in the PyTorch codebase. One reason is that Pyrefly checks the whole codebase the same way. Take the code example below: unless MyPy runs in strict mode (or with --check-untyped-defs), it doesn't type check the bodies of untyped functions, so errors like this can go unnoticed. Pyrefly, running in one consistent mode across the codebase, catches them.

def foo():
    return 1 + "" # pyrefly error

Seamless IDE Experience

Pyrefly integrates natively with many major IDEs, bringing real-time type feedback, hover documentation, and instant diagnostics directly into the editor that match your local and CI results. Now PyTorch contributors using a diverse range of IDEs can spot type errors as they code and be confident their results are consistent, reducing context-switching and making it easier to maintain high code quality. VSCode users can download our IDE extension here. Once enabled, it will automatically find the configuration file in the PyTorch project.

Advanced Typing Capabilities

Pyrefly brings advanced typing features to PyTorch, including robust support for complex typing patterns and strict adherence to Python typing specifications. This empowers contributors to write safer and more expressive code, while maintaining performance and a smooth developer experience.

Pyrefly’s inference capabilities can also enable developers to detect type errors even in code that lacks explicit type annotations. This means that legacy code, experimental modules, and fast-moving prototypes can benefit from increased type safety, without requiring a massive upfront investment in annotation. It can also help identify areas of code that could benefit from more explicit type annotations, helping us move forward with our goals of increasing type coverage in the codebase. Currently, return type inference is not enabled by default in PyTorch, but we are actively working to add annotations and fix type issues in order to un-gate this feature in the near future.

def foo():
    return 1
foo() + "hello" # mypy: no error, # pyrefly: error [unsupported-operation]

Get Started with Pyrefly

Contributors to PyTorch can get started using Pyrefly by installing the extension in their editors, and can start using it for local type checking quickly and easily using lintrunner:

lintrunner init
lintrunner

Contributors to Helion can also get started by installing the IDE extension, and can run a local type check with the repository's lint.sh script:

./lint.sh install && ./lint.sh

Pyrefly is also integrated into our CI suite under the lint job, so the same rules applied during local development are enforced on every PR. When you open a pull request, you can find the Pyrefly results by navigating to the “Checks” tab and selecting the lint job.

If you’re not a PyTorch contributor but still want to try Pyrefly on your own project, you can get the VS Code extension here or check out the Pyrefly documentation.

Future Work

Switching to Pyrefly marks a practical and meaningful advancement for the PyTorch project. Developers are already seeing the benefits of faster and more consistent type checking, and the initial rollout has helped uncover and resolve a substantial number of bugs. This transition has streamlined workflows and laid the foundation for ongoing improvements in both code quality and developer experience.

Looking ahead, we hope to continue seeing performance improvements from Pyrefly as the tool matures. We’re also excited to partner with the Pyrefly team to further improve typing across the codebase. Strengthening type annotations in one of the most widely used AI/ML libraries will enable maintainers and the broader community to more confidently leverage PyTorch in production environments. Deploying a newer, faster type checker with Pyrefly is only the first step of that journey.

As always, community feedback is invaluable. We encourage PyTorch contributors and users to share their experiences, report issues, and suggest improvements as we continue refining the type checking workflow. If you have questions or wish to provide feedback to the Pyrefly team, you can do so on Discord, or submit bug reports by opening a GitHub issue in the Pyrefly repository.

Finally, we want to extend our sincere thanks to both the PyTorch and Pyrefly teams, as well as the community, for their feedback and testing throughout this transition.

Why I’m Joining the PyTorch Foundation

https://pytorch.org/blog/why-im-joining-the-pytorch-foundation/ (Wed, 11 Feb 2026 16:55:38 +0000)

I want to start by thanking Matt White for everything he has built over the past two years. The growth of the PyTorch Foundation speaks for itself. What began as a single-project foundation is now a multi-project home for some of the most critical infrastructure in AI. That did not happen by accident. It is the result of real technical leadership, genuine community investment, and a clear belief in open collaboration. Matt is now stepping into the role of Global CTO of AI at the Linux Foundation and will transition to the role of CTO at the PyTorch Foundation, where he will focus on the technical strategy and direction that will define what’s possible next.

I’m thrilled to be joining the PyTorch Foundation as its new Executive Director. Here’s why.

The Most Important Open Source Projects in the World

There is not a more important open source project in the world right now than PyTorch. The daily onslaught of new state-of-the-art models proves it. When you hear about models writing compilers from scratch capable of compiling the Linux kernel, you’re getting a glimpse of the future that PyTorch makes possible.

But here’s what I think people outside of our community are only beginning to understand: the PyTorch Foundation is no longer just about PyTorch.

vLLM has become the inference engine of choice for the industry. When a new model drops, it runs on vLLM on day one, which tells us where the center of gravity lives. Inference is the largest workload in human history, and it runs on a PyTorch Foundation project.

DeepSpeed is pushing the boundaries of training efficiency at a scale that was unthinkable a few years ago. Ray is powering the orchestration and scaling layer that lets AI workloads run across the industry. These are foundational technologies with massive communities of their own, and they chose to make their home here.

Training. Inference. Orchestration. The critical layers of the AI stack live under one roof.

Every Innovation Story Is an Infrastructure Story

I’ve spent my career finding the infrastructure layer of emerging technology waves and building open source ecosystems around them. I co-founded OpenStack in 2010 and built the OpenStack Foundation (now OpenInfra Foundation), spending over a decade helping create the open source cloud. Last year we merged the OpenInfra Foundation with the Linux Foundation, and I became General Manager of AI and Infrastructure and Executive Director of the LF AI and Data Foundation. Now I get to put that experience into action with the PyTorch Foundation.

If there’s one thing I’ve learned across all of that, it’s that every innovation story is an infrastructure story if you know where to look. AI is going to reshape every aspect of the lives of every human being on earth, and it is going to do so at a speed that makes previous technological transitions look slow. The industrial revolution played out over generations. The internet transformed society over decades. AI is compressing that arc into years. The infrastructure that makes all of this possible is being built right now, in the open, by the communities in this foundation.

We don’t want any one company or country to dominate such critical technologies. They have to be built together by communities that trust each other enough to do the hard work side by side. The best open source foundations foster the conditions that let communities lead. They keep the path open for the widest possible participation and the largest possible impact. That’s what we need to do again, and I’m here to do that work with all of you. 

The Energy Is Real

I had the opportunity to attend PyTorchCon in San Francisco last October, and I was in awe of the community energy in that place. That’s not easy to pull off in Moscone, and it’s not something you’ll find at just any open source conference. I’ve been to many of them. It reminded me deeply of the early OpenStack days when our summits were doubling every year, and people were genuinely having fun while changing the world.

If you’re part of this community, whether you contribute to PyTorch, vLLM, DeepSpeed, Ray, or the ecosystem around them, you may not fully realize it yet, but that’s exactly what you’re doing. Enjoy the ride.

What Comes Next

My prime directive is clear. Serve the communities that make this foundation what it is. Advocate for the open path that leads to the most innovation, the widest impact, and the largest number of people served by this technology. And make sure that every community that calls this foundation home knows that it belongs here and that its work matters.

If you’re headed to a PyTorch Conference, a PyTorch Day, or anywhere else this community gathers, come find me. I want to meet the people doing this amazing work. The best part of open source has always been the people, and I can’t wait to get to know more of you.

Let’s go build the future.

Mark Collier is Executive Director of the PyTorch Foundation, General Manager of AI and Infrastructure at the Linux Foundation, and Executive Director of the LF AI and Data Foundation. He co-founded the OpenStack project in 2010 and spent 13 years building the OpenStack Foundation and open source cloud community.

PyTorch Foundation: The Next Chapter, Together

https://pytorch.org/blog/pytorch-foundation-the-next-chapter-together/ (Wed, 11 Feb 2026 16:55:14 +0000)

For nearly two years, I’ve had the privilege of serving as Executive Director of the PyTorch Foundation. As I look back on what we have accomplished together, one thing stands out clearly: our momentum is not accidental. It is the result of a global community of maintainers, contributors, researchers, practitioners, member organizations, and volunteers who have chosen collaboration, openness, and technical rigor as the path to progress.

This post is both a thank you and a transition update, shared first and foremost with the PyTorch community.

What we built in a short time

In a relatively short period, the PyTorch Foundation has evolved from a single-project foundation centered on PyTorch into a multi-project home for critical components across the AI development lifecycle. Today, the Foundation proudly hosts four major projects: PyTorch, vLLM, DeepSpeed, and most recently Ray. Alongside these hosted projects, the broader PyTorch ecosystem has expanded to more than 100 projects, including Unsloth, verl, SGLang, FEAST, and many other high-quality open source efforts that are pushing the state of the art forward.

At the same time, our membership has grown to 33 organizations, nearly doubling, and we updated our membership tiers to better reflect the scale and maturity of our ecosystem. Those member commitments matter, because they translate into real investment in the shared infrastructure and community programs that enable open source AI to thrive.

Stronger governance and deeper technical collaboration

As our technical scope expanded, so did our governance. We launched the initial Technical Advisory Council and supported its growth into a more active forum for cross-project alignment. We also established five core working groups: CI Infrastructure, Multi-Cloud, Ecosystem, Accelerators, and Security.

These groups are where hard, practical problems get solved: keeping CI reliable and scalable, improving portability and cost efficiency, coordinating cross-project priorities, strengthening security posture, and making it easier for developers and organizations to adopt and deploy PyTorch and related projects. The result has been measurably increased technical engagement, clearer project roadmaps, and more consistent collaboration patterns across the Foundation’s hosted projects and the broader ecosystem.

A bigger global footprint, powered by the community

The growth of PyTorch is global, and our community programs have expanded accordingly.

We grew from a conference of roughly 300 attendees to a flagship PyTorch Conference in San Francisco that welcomed more than 3,000 participants. We successfully launched PyTorch Days with events in Paris and Beijing, and we are continuing to expand our global presence. In 2026, we will hold three PyTorch Conferences: Europe in Paris (April), China in Shanghai (September), and our flagship event, North America in San Jose (October). These will be complemented by additional PyTorch Days, starting in Bengaluru this past weekend, with more events in development, including Beijing, Seoul, and others.

We also launched the PyTorch Ambassadors program, now approaching 50 ambassadors, with another cohort planned. This is one of the most important community programs we run, because it scales something no single team can manufacture: local leadership. Ambassadors host meetups, welcome new contributors, and help PyTorch show up meaningfully in regions and communities around the world. In parallel, we’ve been building a speaker bureau to connect domain experts from the community with events seeking credible technical speakers.

Academic outreach, research engagement, and education

Another area of focus has been strengthening ties between research, education, and open source practice.

We kicked off an Academic and OSPO outreach program to engage academic labs and university Open Source Program Offices, with early work involving UC Berkeley, UC Santa Cruz, Stanford, the University of Vermont, and Caltech. The goal is to help students build practical open source skills, create clearer pathways from research to production, and identify emerging open source AI projects that could benefit from Foundation support.

We also increased the Foundation’s participation in major research and practitioner venues, supporting workshops, posters, and talks at MLSys, ICML, NeurIPS, and UC Berkeley’s AgentX program. Across the year, I joined many leaders from the PyTorch community in speaking at more than 100 events worldwide to advocate for PyTorch, the Foundation, and open source AI as a durable strategy for innovation.

Finally, the educational output from the community has been exceptional. In 2025, we published more than 130 pieces of educational content, including tutorials, webinars, and blogs, averaging nearly one substantive item every three days. That pace reflects both the depth of expertise across the community and the rate at which the ecosystem continues to evolve.

We also made meaningful progress toward scalable professional development. At the last PyTorch Conference, we kicked off onsite training for the PyTorch Certified Associate program with strong participation. In the coming months, we expect to publish the corresponding exam and online course, and then begin building the content pathway toward a PyTorch Certified Professional designation. The intent is to support developers who want to demonstrate practical PyTorch fluency, while giving employers a clearer signal for hiring and workforce development.

Infrastructure that scales with the ecosystem

Behind every reliable open source ecosystem is infrastructure that works. Over the past two years, we continued strengthening CI reliability and observability, expanded monitoring and logging, and progressed the migration of our download site to the Cloudflare CDN.

Just as importantly, the Foundation’s CI would not be sustainable without the support of member organizations and partners who contribute engineering effort, hardware, and operational expertise. Contributions, current and in progress, from Meta, AWS, AMD, Intel, Microsoft, and NVIDIA have been critical. We have also advanced a multi-cloud strategy so we can diversify our footprint across hyperscalers and neo-clouds, manage cost, and maintain the performance and scale that developers and production users depend on.

What comes next

Even with this progress, the next phase demands more. Key priorities ahead include:

  • Expanding the hosted project portfolio, including adjacent domains such as agentic AI, environments, and reinforcement learning
  • Further diversifying and optimizing CI architecture and costs
  • Onboarding additional project CI workloads where shared accelerator access unlocks faster iteration
  • Expanding training and certification into a durable revenue stream that strengthens Foundation sustainability
  • Deepening community programs, including initiatives such as mentorship and stronger global enablement

As the scope grows, there is a straightforward operational reality: leadership capacity must scale so that organizational throughput, not leadership bandwidth, sets our pace.

A leadership transition to support the next stage

To support this next stage, I’m sharing a leadership transition that takes effect immediately.

I will be stepping into the role of Chief Technology Officer for the PyTorch Foundation, alongside my new role as Global CTO of AI at the Linux Foundation. At the same time, Mark Collier will join the PyTorch Foundation as our new Executive Director.

Mark brings deep experience building and scaling open infrastructure ecosystems, including founding OpenStack and the OpenInfra Foundation. As Executive Director, he will lead the operational and business execution of the Foundation, working closely with the Governing Board. His responsibilities include oversight of Foundation committees (including Finance and Marketing), community programs such as Ambassadors, Foundation-led events, staff management, finances, and membership development. Ultimately, he will be accountable for the overall direction and operations of the Foundation in partnership with the Governing Board.

As CTO, I will focus on technical strategy and execution across the Foundation: supporting the TAC and working groups; advancing our hosted projects and ecosystem alignment; strengthening CI and multi-cloud infrastructure; and driving technical programs, including Academic and OSPO outreach and PyTorch Certified. This structure is intended to increase clarity, accountability, and speed, while preserving community-led technical governance.

Quotes

“It’s great to see the PyTorch Foundation enter a new phase, just months after it evolved into an umbrella foundation. With Mark as the Executive Director and Matt as the CTO, the foundation acquires the level of maturity required by its ambitions. I can’t wait to help build the future of PyTorch with the new leadership and the rest of the TAC.”

– Luca Antiga, CTO, Lightning AI and Chair, PyTorch Foundation Technical Advisory Council (TAC)

“Watching the PyTorch Foundation grow into an umbrella ecosystem has been inspiring—it’s set PyTorch up not only for the short term, but for a long arc of impact foundational to AI. Congrats to Matt on an incredible chapter, and a warm welcome to Mark. I’m excited for where we take PyTorch next!” 

– Joe Spisak, Product Director, Meta Superintelligence Labs & PyTorch Core Maintainer

“The growth of the PyTorch Foundation speaks for itself. Thanks to Matt White for everything he has built. What began as a single-project foundation is now a multi-project home for some of the most critical infrastructure in AI. That did not happen by accident. It is the result of real technical leadership, genuine community investment, and a clear belief in open collaboration. I’m excited to keep that momentum going that will define what’s possible next.”

– Mark Collier, Executive Director, PyTorch Foundation

Thank you

I want to close with an explicit note of appreciation. The PyTorch Foundation’s progress is not the product of any single organization or individual. It is the result of thousands of community members: maintainers, contributors, reviewers, working group participants, event organizers, speakers, educators, and member company teams who consistently choose collaboration over fragmentation and long-term stewardship over short-term advantage.

Thank you for the trust, the effort, and the standards you bring to this community.

I’m excited for what comes next, and I’m particularly looking forward to working with Mark as he steps into the Executive Director role. Please join me in welcoming him and supporting him as he begins this next chapter with us.

We have built something strong. Now we scale it.

Matt White

CTO, PyTorch Foundation
Global CTO of AI, Linux Foundation

Accelerating Mamba2 with Kernel Fusion

https://pytorch.org/blog/accelerating-mamba2-with-kernel-fusion/ (Fri, 06 Feb 2026 22:48:54 +0000)

Summary

In this post, we discuss how we optimized the Mamba-2 State-Space Dual (SSD) module with a fused Triton kernel that yields speedups of 1.50x-2.51x on NVIDIA A100 and H100 GPUs. To achieve this, we fused all five SSD kernels into a single Triton kernel with careful synchronization. To our knowledge, this is the first end-to-end Triton fusion of all five SSD kernels. This reduces launch overhead and avoids redundant memory operations, making the kernel faster across all input sizes. The rest of this post covers how we fused the SSD kernels, what bottlenecks remain, benchmark results, and our plans to release the kernel as open source so the community can benefit.

Figure 1. Fused SSD Triton Kernel A100 and H100 Speedups

Background

Mamba-2, an optimized successor to the original Mamba model, is a sequence model based on the state-space duality (SSD) framework, which connects structured state-space models (SSMs) with attention-based transformers. One key advantage of Mamba-style models is scalability to long sequences: Mamba’s state-space mechanism scales linearly with context length. In practice, doubling the input sequence length roughly doubles Mamba’s compute and memory needs, whereas self-attention would quadruple them. This makes Mamba-2 especially attractive for extremely long contexts, such as 128K tokens and beyond.
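To make the scaling claim concrete, here is a back-of-the-envelope sketch. It counts only the leading-order work term of each mechanism (one state update per token for the SSM, one pairwise interaction per token pair for attention) and ignores constants, heads, and memory. The function names and the d_model/d_state values are illustrative assumptions, not Mamba-2's actual cost model.

```python
def ssm_scan_cost(seq_len, d_model, d_state):
    # Each token updates a (d_model x d_state) state once: linear in seq_len.
    return seq_len * d_model * d_state


def attention_cost(seq_len, d_model):
    # Each token interacts with every token: quadratic in seq_len.
    return seq_len * seq_len * d_model


def growth(cost_fn, short_len, long_len, *dims):
    """How much more work the long sequence needs than the short one."""
    return cost_fn(long_len, *dims) / cost_fn(short_len, *dims)


if __name__ == "__main__":
    d_model, d_state = 4096, 128  # illustrative sizes only
    print("2x length, SSM:       ", growth(ssm_scan_cost, 4096, 8192, d_model, d_state))
    print("2x length, attention: ", growth(attention_cost, 4096, 8192, d_model))
    print("4K -> 128K, SSM:      ", growth(ssm_scan_cost, 4096, 131072, d_model, d_state))
    print("4K -> 128K, attention:", growth(attention_cost, 4096, 131072, d_model))
```

Under this model, going from 4K to 128K tokens multiplies the SSM term by 32 but the attention term by 1024.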

IBM’s Granite 4.0 model family recently adopted a hybrid architecture that combines Mamba-2 blocks with transformer blocks. In Granite 4.0, nine Mamba-2 layers are used for every one attention layer to handle long-range context efficiently. With Mamba-2 becoming integral to such models, optimizing Mamba-2’s performance is critical for faster inference. The core of Mamba-2’s computation is the SSD module, which replaces the attention mechanism in each layer. The original Mamba2 SSD implementation is mostly bottlenecked by memory bandwidth and latency, and it writes and reads intermediate data between kernels, so there is room for improvement. In this blog, we focus on accelerating this SSD prefill operation with an optimized fused kernel.

Mamba2 Operations

The operations that make up a typical Mamba2 block are listed in Table 1. We focused on fusing the five SSD kernels because they behave as one conceptual SSD operation, though further fusion (e.g., convolution and layernorm) may be possible as discussed later.

 

Operation               Description
Layernorm               Helps with numerical stability
In Projection           Projects input to SSD channels/size
Depthwise Convolution   Mixes the last few tokens
SSD Chunk Cumsum        Computes the dt per token and cumulative decay within a chunk
SSD Chunk State         Computes the state at the end of this chunk in isolation
SSD State Passing       Computes the global states at the end of each chunk
SSD BMM                 Computes how each chunk of input x affects the corresponding chunk of output y
SSD Chunk Scan          Computes each chunk of y from the corresponding chunk of x and the previous chunk’s global state
Layernorm               Helps with numerical stability
Out Projection          Projects output to the model’s hidden dim

Table 1. Mamba2 operations
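The sketch below shows how the five SSD steps in Table 1 fit together. It makes heavy simplifying assumptions: a single head, a scalar input per token (head dim 1), a per-token log-decay log_a standing in for dt * A, no skip connection, and pure Python instead of Triton. All function and variable names are hypothetical. It checks the chunked computation against the plain token-by-token recurrence h_t = a_t * h_{t-1} + B_t * x_t, y_t = C_t . h_t.

```python
import math
import random


def ssd_reference(x, log_a, B, C):
    """Token-by-token recurrence: h_t = a_t * h_{t-1} + B_t * x_t, y_t = C_t . h_t."""
    n = len(B[0])
    h = [0.0] * n
    ys = []
    for t in range(len(x)):
        a = math.exp(log_a[t])
        h = [a * h[i] + B[t][i] * x[t] for i in range(n)]
        ys.append(sum(C[t][i] * h[i] for i in range(n)))
    return ys


def ssd_chunked(x, log_a, B, C, chunk_size):
    """Same outputs, computed chunk by chunk in the order of the five SSD steps."""
    T, n = len(x), len(B[0])
    H = [0.0] * n  # global state entering the current chunk
    ys = []
    for s in range(0, T, chunk_size):
        toks = list(range(s, min(s + chunk_size, T)))
        # Chunk Cumsum: cumulative log-decay from the chunk start up to each token.
        cum, acc = [], 0.0
        for t in toks:
            acc += log_a[t]
            cum.append(acc)
        # BMM: C_t . B_k for every pair of tokens inside the chunk.
        CB = [[sum(C[t][i] * B[k][i] for i in range(n)) for k in toks] for t in toks]
        # Chunk State: state at the chunk end, assuming a zero incoming state.
        S = [sum(math.exp(cum[-1] - cum[j]) * B[t][i] * x[t] for j, t in enumerate(toks))
             for i in range(n)]
        # State Passing: decay the incoming global state across the chunk, then add S.
        H_next = [math.exp(cum[-1]) * H[i] + S[i] for i in range(n)]
        # Chunk Scan: intra-chunk contributions plus the decayed incoming state H.
        for j, t in enumerate(toks):
            intra = sum(math.exp(cum[j] - cum[k]) * CB[j][k] * x[toks[k]]
                        for k in range(j + 1))
            inter = math.exp(cum[j]) * sum(C[t][i] * H[i] for i in range(n))
            ys.append(intra + inter)
        H = H_next
    return ys


if __name__ == "__main__":
    random.seed(0)
    T, n = 11, 3
    x = [random.uniform(-1, 1) for _ in range(T)]
    log_a = [-random.uniform(0.01, 0.5) for _ in range(T)]
    B = [[random.uniform(-1, 1) for _ in range(n)] for _ in range(T)]
    C = [[random.uniform(-1, 1) for _ in range(n)] for _ in range(T)]
    ref = ssd_reference(x, log_a, B, C)
    for size in (1, 4, 16):
        out = ssd_chunked(x, log_a, B, C, size)
        print(size, max(abs(a - b) for a, b in zip(ref, out)))
```

The only cross-chunk dependency in this sketch is H, the State Passing result, which is exactly the dependency that becomes hard to manage once the steps are fused into one launch.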

Why Do We Need Kernel Fusion?

During prefill, which is the forward pass over the prompt or input sequence before token generation, Mamba-2’s SSD module executes as a pipeline of five GPU kernels. In the original implementation, these five kernels run sequentially on the GPU. However, launching multiple small kernels in sequence incurs significant overhead and prevents the GPU from reusing data between stages efficiently. By applying kernel fusion, we gain several key benefits:

  • Eliminating Kernel Launch Overheads: One launch instead of five reduces CPU-GPU synchronization and scheduling delays.
  • Improving Cache Locality: Data produced in one stage is immediately consumed by the next within the same threadblock, increasing cache hits and reducing global memory traffic.
  • Overlapping Computation: Different parts of the fused kernel can execute in parallel (where independent), better utilizing GPU resources.

Our solution fuses all five kernels into a single Triton kernel, so that the entire SSD prefill computation for a layer happens within one GPU launch.

Efficient Kernel Fusion Technique

Unlike a simple matmul + activation fusion, SSD fusion is complex because the computation spans multiple steps with complicated dependencies. The original implementation relied on implicit synchronization across kernels, which disappears when we fuse everything. In this section, we discuss why that matters and our approach to making fusion work in practice.

The five steps of the Mamba2 SSD were originally implemented as five separate kernels: Chunk Cumsum, BMM, Chunk State, State Passing, and Chunk Scan, which operate on fixed-size chunks of tokens. The figure below illustrates the dependencies between these kernels.

Figure 2. Mamba2 SSD Prefill Kernel Graph

The State Passing step has dependencies between chunks, and the original State Passing kernel handled this by looping over chunks within threadblocks and splitting the state’s channels across threadblocks for parallelism. With this State Passing loop and the implicit global synchronization between kernel launches, all dependencies were handled in the original kernels.

The real technical challenge comes when we try to fuse all five kernels into a single launch. Once fused, we lose the implicit global synchronization that the original kernels relied on, so we must explicitly manage both within-chunk and across-chunk dependencies. Most of the dependencies are between different steps but the same chunk, so for the three largest kernels, Chunk State, State Passing, and Chunk Scan, these intra-chunk dependencies could be handled by running all steps of a particular chunk on the same threadblock. This would also give us the ability to keep intermediate data between steps in registers or L1 cache (private to each SM) since the data will be used on the same threadblock.

However, applied naively, this approach is neither possible nor correct. The original State Passing kernel has the aforementioned loop, which makes its threadblock grid not match the original Chunk State and Chunk Scan kernels. Furthermore, having separate threadblocks for each chunk would remove the natural synchronization and correctness provided by looping over chunks within a single threadblock.

To make fusion possible, we split the iterations of the State Passing loop across chunks into separate threadblocks so the threadblock grids match. We get correctness by ordering these threadblocks with atomics, a form of serialization that looks quite inefficient on the surface but can be mitigated by overlapping with the other two parts.
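Here is a rough CPU-side analogy of this ordering scheme. Python threads stand in for threadblocks, a lock-protected counter stands in for an atomic integer in global memory, and states are scalars. The class and function names are hypothetical, and this is not the Triton kernel itself. Each chunk's worker can do its Chunk State work independently. Before running State Passing, it spins until the counter shows that chunk c-1 has published its global state, and then it publishes its own.

```python
import threading
import time


class AtomicCounter:
    """Hypothetical stand-in for an integer in GPU global memory updated with atomics."""

    def __init__(self):
        self._value = 0
        self._lock = threading.Lock()

    def load(self):
        with self._lock:
            return self._value

    def fetch_add(self, delta=1):
        with self._lock:
            old = self._value
            self._value += delta
            return old


def ordered_state_passing(local_states, chunk_decays):
    """One worker per chunk; chunk c runs State Passing only after chunk c-1 has published.

    local_states[c]: chunk c's end state assuming a zero incoming state (Chunk State output).
    chunk_decays[c]: total decay across chunk c.
    Returns the global end-of-chunk states and the order in which chunks published them.
    """
    num_chunks = len(local_states)
    global_states = [0.0] * num_chunks
    published = AtomicCounter()  # number of chunks whose global state is available
    completion_order = []

    def threadblock(c):
        # Chunk State work for chunk c would happen here, in parallel with other chunks.
        while published.load() < c:  # busy-wait until chunks 0..c-1 have published
            time.sleep(0)  # yield the CPU; a GPU warp would simply spin
        incoming = global_states[c - 1] if c > 0 else 0.0
        global_states[c] = chunk_decays[c] * incoming + local_states[c]
        completion_order.append(c)
        published.fetch_add(1)  # release chunk c+1
        # Chunk Scan for chunk c would continue here using `incoming`.

    # Start workers in reverse order: correctness comes from the counter, not launch order.
    workers = [threading.Thread(target=threadblock, args=(c,))
               for c in reversed(range(num_chunks))]
    for w in workers:
        w.start()
    for w in workers:
        w.join()
    return global_states, completion_order


if __name__ == "__main__":
    states, order = ordered_state_passing([1.0, 2.0, 3.0, 4.0], [0.5, 0.5, 0.25, 1.0])
    print(states, order)  # [1.0, 2.5, 3.625, 7.625] [0, 1, 2, 3]
```

OS threads are preempted, so this analogy cannot deadlock. On a GPU, however, a resident threadblock that spins on a predecessor that has not been scheduled yet can hang the kernel. That is one reason the order in which threadblocks pick up chunks matters.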

For example, if we ran 8 chunks in parallel, we would expect a ~8x local slowdown from the State Passing serialization. However, the fused State Passing step takes only a small fraction of the time spent in the three large steps, especially since it no longer has to read the state from global memory (the state is already in the threadblock from the fused Chunk State).

By Amdahl’s law, we would expect the runtime to change to (State Passing fraction) * 8 + (1 – State Passing fraction) * 1. For example, if the State Passing step was only 1/7th of the combined time excluding synchronization, we would get (1/7) * 8 + (6/7) * 1 = 2, implying a 2x overall slowdown. However, this does not account for overlap. Since the synchronization of State Passing can overlap with the Chunk State and Chunk Scan computation, the slowdown would be roughly: 

State Passing compute time + max(other compute time, State Passing synchronization time)

= 1/7 + max(6/7, 1/7 * 7) = 1.14x

If State Passing were a smaller fraction of the total runtime, or if fewer chunks were processed concurrently, we could theoretically avoid any serialization slowdown in all but the first chunks.
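The arithmetic in the last few paragraphs can be reproduced with a small helper that uses exact fractions. It assumes, as in the example, that each chunk's State Passing waits on the State Passing of the other parallel_chunks - 1 chunks in flight. That is our reading of the 1/7 * 7 term. Runtimes are normalized so that 1 means no slowdown, and the function names are ours.

```python
from fractions import Fraction


def serialized_runtime(sp_fraction, parallel_chunks):
    """Amdahl-style estimate: State Passing fully serialized, nothing overlapped."""
    return sp_fraction * parallel_chunks + (1 - sp_fraction)


def overlapped_runtime(sp_fraction, parallel_chunks):
    """State Passing compute + max(other compute, State Passing synchronization wait)."""
    sync_wait = sp_fraction * (parallel_chunks - 1)
    return sp_fraction + max(1 - sp_fraction, sync_wait)


if __name__ == "__main__":
    sp = Fraction(1, 7)
    print(serialized_runtime(sp, 8))                # 2
    print(overlapped_runtime(sp, 8))                # 8/7, about 1.14x
    print(overlapped_runtime(Fraction(1, 10), 8))   # 1: synchronization fully hidden
```

The last call illustrates the point above: once State Passing is a small enough fraction, the synchronization wait fits entirely under the other compute, and the modeled slowdown disappears.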

Figure 3. State Passing Overhead Overlap

Figure 3 shows the theoretical synchronization delays, which are high for the first chunks run in parallel, but settle down to a low overhead in all later chunks. We can see that although chunk 8 depends on chunk 7, it only has to busy-wait 1 unit of time instead of 8 since the chunk 0 Chunk Scan and chunk 8 Chunk State overlap with the State Passing of chunks 1-6. In practice, NVIDIA Nsight Compute profiles show that fewer than 3% of warp stalls (idle thread time) are caused by the State Passing synchronization, implying that the serialization latency is hidden.

The BMM and Chunk Cumsum steps are extremely fast compared to the other three. BMM splits work along ngroups instead of nheads, and Chunk Cumsum has its threadblocks handle multiple heads for efficiency. For simplicity, we launch separate threadblocks for these two steps (the first few threadblocks work on them) and have the threadblocks for the other three steps await their BMM and Chunk Cumsum dependencies with atomics.
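To illustrate why BMM splits along ngroups, we assume the grouping convention of the reference mamba_ssm code. There, B and C carry an ngroups dimension, and each group is shared by nheads // ngroups consecutive heads. The C.B products then only need to be computed once per group and can be reused by every head in that group. The helpers below are hypothetical.

```python
def head_to_group(head, nheads, ngroups):
    """Index of the B/C group that a given head reads."""
    if nheads % ngroups:
        raise ValueError("nheads must be divisible by ngroups")
    return head // (nheads // ngroups)


def heads_by_group(nheads, ngroups):
    """Which heads share each group's C.B product."""
    groups = {g: [] for g in range(ngroups)}
    for h in range(nheads):
        groups[head_to_group(h, nheads, ngroups)].append(h)
    return groups


if __name__ == "__main__":
    print(heads_by_group(nheads=8, ngroups=2))  # {0: [0, 1, 2, 3], 1: [4, 5, 6, 7]}
```

With 8 heads and 2 groups, this convention needs 2 C.B matrices per chunk instead of 8, which keeps BMM small relative to the per-head steps.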

When a threadblock begins executing the kernel, it is assigned to work on the Chunk Cumsum step unless all Chunk Cumsum work has already been assigned. Similarly, if there is no unassigned Chunk Cumsum work, the threadblock would be assigned to the BMM step if available. After both of these fast steps have been fully assigned to threadblocks, later threadblocks each start processing a chunk in Chunk State, process that same chunk in State Passing, and finally output that chunk after Chunk Scan.
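The assignment order described above can be sketched as a pure function from a threadblock's arrival ticket to its work item. This assumes the ticket comes from an atomic increment performed when the threadblock starts, so tickets follow the order in which threadblocks actually begin running. The function and work labels are hypothetical, not the kernel's own names.

```python
def assign_work(ticket, n_cumsum_blocks, n_bmm_blocks, n_chunks):
    """Map a threadblock's arrival ticket (0, 1, 2, ...) to the work it should do."""
    if ticket < n_cumsum_blocks:
        return ("chunk_cumsum", ticket)
    ticket -= n_cumsum_blocks
    if ticket < n_bmm_blocks:
        return ("bmm", ticket)
    ticket -= n_bmm_blocks
    if ticket < n_chunks:
        # Chunk State -> State Passing -> Chunk Scan for one chunk, on one threadblock.
        return ("fused_chunk", ticket)
    return None  # all work has been handed out


if __name__ == "__main__":
    for t in range(10):
        print(t, assign_work(t, n_cumsum_blocks=2, n_bmm_blocks=3, n_chunks=4))
```

Under this assumption, a threadblock holding chunk c received its ticket after the threadblocks holding chunks 0 through c-1. Those predecessors have therefore already started when it begins waiting on them, which is one common way to avoid deadlocks from spin-waiting on threadblocks that were never scheduled.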

While kernel fusion improves data reuse and speeds up the SSD, additional optimizations are necessary to achieve maximum performance. These include reordering threadblocks to hide serialization latency, adding cache hints to loads/stores to prioritize reused data, separating special cases outside of the fused kernel to reduce register pressure, changing some intermediate datatypes, tuning the chunk size, and restructuring operations for less latency. These optimization techniques are described in more detail in Appendix A.

Remaining Bottlenecks

In this section, we analyze the bottlenecks in the optimized fused SSD kernel using Nsight Compute to examine the final utilization, stall patterns, and resource tradeoffs.

At a high level, we can look at the compute and memory utilization of the fused kernel to get an idea of what limits this kernel.

Figure 4. A100 Nsight Compute Summary

Figure 5. H100 Nsight Compute Summary

We can see that overall fused SSD compute utilization is about 40-50% and memory utilization is about 65-75%. It is not possible to achieve 100% utilization due to the initial load/store latency and other overheads, but it’s usually possible to get at least 80% in a well-optimized kernel. For context, the H100 and A100 matmuls used in Mamba2 get 85-96% compute utilization. Since neither compute nor memory has good utilization in the SSD kernel, the bottlenecks are more complicated than just memory bandwidth or compute throughput.

We can look at the warp state statistics to see what warps are stalled on. “Selected” means that the warp executed a new instruction, but “Stall Long Scoreboard” and “Stall Barrier” indicate that warps are idle waiting for L2/VRAM or synchronizing.

Figure 6. Warp State Statistics for the fused SSD kernel on an H100

There are a few ways to reduce the effect of these stalls and improve the compute or memory utilization:

  1. Increase occupancy
  2. Increase instruction-level parallelism
  3. Optimize the code to use less synchronization and memory ops or cache data better

Occupancy

Modern NVIDIA GPUs have 12-16 warps (groups of 32 threads) per warp scheduler, and each of these warp schedulers can issue a new instruction every cycle. If we only have 1 warp in each scheduler, we waste cycles every time that the warp stalls. However, if we have 16 warps in each scheduler, each warp could be stalled about 15/16 of the time without leaving the hardware idle. Occupancy is the fraction of available warp slots that are actually filled with active warps. Increasing occupancy helps hide memory and instruction latency, increasing GPU utilization.

Figure 7. Occupancy for the fused SSD kernel on an H100

This fused kernel only gets 25% occupancy in the current config, limited by registers and shared memory. Although we can increase the number of warps and reduce the registers per thread to increase occupancy, this reduces performance in practice, likely due to increased synchronization costs and higher register pressure.
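A simplified occupancy calculator helps connect these numbers. It assumes A100/H100-class SM limits (64 resident warps, 65,536 32-bit registers, and 4 warp schedulers per SM) and uses H100's 228 KB of per-SM shared memory as the default. It ignores register and shared-memory allocation granularity and per-block limits, so treat its results as estimates. The kernel's actual register, shared-memory, and warp counts are not given in this post, so the example configuration is hypothetical.

```python
def occupancy(warps_per_block, regs_per_thread, smem_per_block_bytes,
              max_warps_per_sm=64, regs_per_sm=65536, smem_per_sm_bytes=228 * 1024):
    """Fraction of warp slots filled, limited by warps, registers, and shared memory."""
    by_warps = max_warps_per_sm // warps_per_block
    by_regs = regs_per_sm // (regs_per_thread * 32 * warps_per_block)
    by_smem = (smem_per_sm_bytes // smem_per_block_bytes
               if smem_per_block_bytes else by_warps)
    blocks_per_sm = min(by_warps, by_regs, by_smem)
    return blocks_per_sm * warps_per_block / max_warps_per_sm


def tolerable_stall_fraction(warps_per_scheduler):
    """Fraction of time each warp may stall while its scheduler still issues every cycle."""
    return 1 - 1 / warps_per_scheduler


if __name__ == "__main__":
    occ = occupancy(warps_per_block=4, regs_per_thread=128, smem_per_block_bytes=48 * 1024)
    warps_per_scheduler = occ * 64 / 4
    print(occ, warps_per_scheduler, tolerable_stall_fraction(warps_per_scheduler))
```

In this idealized model, 25% occupancy leaves about 4 warps per scheduler, so each warp can only be stalled about 3/4 of the time before issue slots go idle. At full occupancy, the figure is 15/16.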

Instruction-Level Parallelism

Instruction-level parallelism (ILP) means designing or optimizing the code so that instructions have fewer immediate dependencies on each other, allowing a warp to issue later instructions even when earlier ones haven’t finished. This provides the same latency-hiding benefit as higher occupancy, but without requiring more warps.

Reducing Synchronization and Data Transfer

Since the warps are usually waiting on memory loads/stores or a barrier, we can improve performance by reducing the number of barriers or by reducing total data transfer through better caching or different block sizes.

Unfortunately, these three optimization techniques can directly clash and introduce tradeoffs. Each SM in the GPU has limited registers and shared memory, so if each threadblock uses too much, occupancy drops. We can increase instruction-level parallelism by loading data in stages, but that requires more registers and shared memory, resulting in lower occupancy. We can also change block sizes to reduce the total data transferred or increase the cache hit rates, but this also requires more resources and reduces occupancy.

This is why the fused kernel does not have very high memory or compute utilization.

Memory Utilization Details

Figure 8. Memory Chart for the fused SSD kernel on an H100

We can see from this chart that the reported 65–75% memory utilization is mostly from reads through the L2 cache. These reads likely include (i) tensors that fit in L2, (ii) tensors reused across multiple threadblocks, (iii) state transfers between threadblocks, and (iv) VRAM reads that naturally pass through L2. Since L1 caches are private to each SM and not coherent across threadblocks, shifting this traffic to L1 is not feasible. Similarly, bypassing L2 for VRAM traffic would not help, as all global memory accesses pass through L2. 

This memory chart suggests that, even though overall memory utilization is below what a well-optimized kernel reaches, the kernel is effectively L2-bound rather than DRAM-bound. Further optimization would therefore require (1) increasing memory utilization, (2) tuning the block sizes / config, or (3) making radical algorithmic changes.

Line-by-Line Stalls

Nsight Compute profiling shows warp stalls line-by-line, helping us check that the warp stalls are for legitimate reasons. As expected, most warp stalls in the fused kernel are from loading data, synchronization, and computation, with only minor overheads from atomics and inter-chunk synchronization. See Appendix B for more details.

Benchmarks

We benchmarked our Triton kernel on typical inference scenarios, batch size 1-32, sequence lengths from 1K up to 256K tokens, and fp16 states. These graphs highlight the speedup of our kernel over the baseline unfused kernels.

Figure 9. NVIDIA A100 Fused Kernel Speedup Graph

Figure 10. NVIDIA H100 Fused Kernel Speedup Graph

The fused SSD kernel is 1.50x-2.51x faster than the unfused implementation on the SSD portion. At short sequence lengths (especially with batch=1), the fused kernel benefits from avoiding the launch overhead of multiple kernels, but these constant costs become amortized for longer sequences. At longer sequence lengths, the fused kernel’s lower data movement becomes even more beneficial as cache thrashing increases. The SSD speedup translates to roughly an 8-13% end-to-end speedup for a model like Mamba-2 2.7B with batch=1 and seq=128K on NVIDIA A100 and H100 GPUs. At shorter sequence lengths, the end-to-end speedup can reach ~20% at 1K context, likely due to the reduced kernel launch overhead.
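One way to relate an SSD-only speedup to an end-to-end number is Amdahl’s law: only the SSD’s share of the runtime gets faster. The sketch below applies it and inverts it to estimate which SSD runtime share would explain a given end-to-end gain. The inputs are hypothetical, "speedup" is taken as a runtime ratio, and the model ignores effects such as the launch-overhead savings that dominate at short sequence lengths.

```python
def end_to_end_speedup(ssd_fraction, ssd_speedup):
    """Amdahl's law: only the SSD's share of the runtime gets faster."""
    return 1.0 / ((1.0 - ssd_fraction) + ssd_fraction / ssd_speedup)


def implied_ssd_fraction(e2e_speedup, ssd_speedup):
    """Invert Amdahl's law: the SSD runtime share that would explain an observed speedup."""
    return (1.0 - 1.0 / e2e_speedup) / (1.0 - 1.0 / ssd_speedup)


# Hypothetical inputs: a 2x SSD speedup and a 1.10x end-to-end speedup.
f = implied_ssd_fraction(1.10, 2.0)
print(round(f, 3))                           # 0.182
print(round(end_to_end_speedup(f, 2.0), 3))  # 1.1
```

Under these assumptions, the SSD would have to account for roughly 18% of the original end-to-end runtime.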

Accuracy and Correctness

The fused kernel is generally accurate and correct, but there are slight differences in output between the fused kernel and the reference solution. These differences depend on the GPU and on the precision used for some intermediate computations. The fused kernel internally uses fp16 for some computations that the original kernels performed in fp32, because this gives a ~16% speedup. Furthermore, the original kernels support either fp32 or fp16 states, but our reported speedups are for fp16 states. The fused kernel still supports the same intermediate datatypes as well as fp32 states. In this section we explain the tradeoffs in accuracy and performance for these different dtype configs.

In Table 2, we report the accuracy of the output y tensor as the percentage of elements that match the original kernels’ output. We test with no threshold (elements must match exactly), a small threshold of 1e-3 absolute and relative tolerance, and a medium threshold of 1e-2. In this table, “exact dtypes” refers to using the same dtypes as the original kernel for all calculations, while “relaxed dtypes” refers to using fp16 for a few calculations. Both the fused and original kernels were run with the same state dtype in each column.

| | fp32 states, exact dtypes | fp16 states, exact dtypes | fp32 states, relaxed dtypes | fp16 states, relaxed dtypes |
|---|---|---|---|---|
| Match @ atol,rtol=0 | 99.696% | 99.337% | 67.307% | 66.823% |
| Match @ atol,rtol=1e-3 | 100.000% | 100.000% | 99.819% | 99.743% |
| Match @ atol,rtol=1e-2 | 100.000% | 100.000% | 100.000% | 100.000% |

Table 2. H100 Accuracy Table
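The match percentages in Table 2 can be computed with an elementwise closeness test. The sketch below assumes the usual |out − ref| ≤ atol + rtol·|ref| criterion (the one used by numpy.isclose and torch.isclose) with atol = rtol = tol. The data is toy data, not the SSD outputs.

```python
import numpy as np


def match_rate(out, ref, tol):
    """Percentage of elements with |out - ref| <= tol + tol * |ref|, i.e. atol = rtol = tol."""
    close = np.isclose(out, ref, rtol=tol, atol=tol)
    return 100.0 * close.mean()


# Toy data: perturb 30% of the elements by a relative 5e-4.
ref = np.linspace(1.0, 2.0, 100, dtype=np.float32)
out = ref.copy()
out[:30] *= np.float32(1.0005)

for tol in (0.0, 1e-3, 1e-2):
    print(tol, match_rate(out, ref, tol))
```

With tol = 0, this reduces to exact equality, which matches the first row of the table.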

Floating point addition is not perfectly associative, so we cannot expect all elements of the output tensor to match with 0 threshold. Even a different Triton launch config can cause very small differences in outputs from the same kernel. For “exact dtypes” (both fp16 and fp32 states), the output is identical for all practical purposes, so this kernel should work with “exact dtypes” even in the most accuracy-sensitive models. For “relaxed dtypes” (which we use in our speedup graphs), we can see that around 1/3 of the elements do not perfectly match the output of the original kernel. However, over 99.7% of the output elements match if we allow the tight threshold of 1e-3. Furthermore, at the commonly-used tolerance of atol=1e-2, rtol=1e-2 (1%), all configurations achieve >99.9995% accuracy, effectively 100%. For practical purposes, we expect the “relaxed dtypes” to have indistinguishable accuracy.
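A quick illustration of why exact matches cannot be expected: the same numbers summed in a different order can give a different result. The blocked float32 reduction only loosely mimics a change of launch config. The block sizes are arbitrary, and NumPy’s own summation order inside each block is an implementation detail.

```python
import numpy as np

# Python floats (IEEE 754 doubles): same three numbers, two groupings.
left = (0.1 + 0.2) + 0.3
right = 0.1 + (0.2 + 0.3)
print(left, right, left == right)  # 0.6000000000000001 0.6 False

# The same effect in float32 when a reduction is split into different block sizes.
rng = np.random.default_rng(0)
x = rng.standard_normal(1 << 16).astype(np.float32)


def blocked_sum(v, block):
    partials = v.reshape(-1, block).sum(axis=1, dtype=np.float32)  # per-block sums
    return partials.sum(dtype=np.float32)                           # combine the partials


exact = x.astype(np.float64).sum()
for block in (32, 64, 128, 256):
    s = blocked_sum(x, block)
    print(block, s, abs(float(s) - exact))
```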

Figure 11. H100 fp32 vs fp16 Accuracy Graph

In Figure 11, we show how our speedup changes when states are in fp32 instead of fp16. Both the fused and original kernels are faster with chunk_size=256 when states are in fp32. This represents a tradeoff of higher compute in return for a smaller state tensor. The fused kernel’s speedup is smaller for fp32 states than for fp16 states, likely because of the different balance of compute and data movement.

Other Architectures

The fused SSD kernel is not limited to Mamba-2. It also applies directly to linear attention, since the SSD formula reduces to the linear attention update when A = 1. In this special case, the fused kernel could be further simplified and optimized for even better performance.
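The reduction can be checked numerically on a simplified, single-head SSD. The per-step decay is modeled as a scalar a_t (standing in for the discretized decay; setting every a_t = 1 is the no-decay case the text describes), and dt and the D residual are omitted. The function names are illustrative.

```python
import numpy as np


def ssd_recurrent(x, a, B, C):
    """Sequential scan: h_t = a_t * h_{t-1} + outer(B_t, x_t), y_t = C_t @ h_t."""
    T, P = x.shape                  # seqlen, headdim
    h = np.zeros((B.shape[1], P))   # state: (dstate, headdim)
    y = np.empty((T, P))
    for t in range(T):
        h = a[t] * h + np.outer(B[t], x[t])
        y[t] = C[t] @ h
    return y


def causal_linear_attention(x, B, C):
    """Queries C, keys B, values x: y_t = sum_{s <= t} (C_t . B_s) x_s."""
    scores = np.tril(C @ B.T)       # (T, T); zero out s > t
    return scores @ x


rng = np.random.default_rng(0)
T, P, N = 16, 4, 8
x = rng.standard_normal((T, P))
B = rng.standard_normal((T, N))
C = rng.standard_normal((T, N))

# With no decay (a_t = 1 everywhere), the SSD output is causal linear attention.
print(np.allclose(ssd_recurrent(x, np.ones(T), B, C), causal_linear_attention(x, B, C)))  # True
```

With a_t = 1, the decay mask disappears, which is why a specialized kernel for this case could skip the cumsum and scaling work entirely.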

New GPU Features

The fused SSD kernel does not currently use newer GPU features such as the Tensor Memory Accelerator (TMA) and thread block clusters on Hopper GPUs, or the Tensor Memory in Blackwell GPUs. These features can greatly reduce register pressure, which would speed up the SSD and could result in faster Triton configs being possible (e.g., larger block sizes). The thread block clusters could especially be useful for broadcast-loading C, B, and CB matrices that are shared across a group of heads in the SSD kernel. This could give further speedups on new GPUs if necessary.

Further Fusion: Convolution and Layernorm

In this fused SSD kernel, we fused the 5 original SSD kernels. The convolution before the SSD and the layernorm after it are also appealing candidates for fusion, because fusing each would remove an entire read and write between kernels. Since the convolution is depth-wise (no channel mixing), the SSD kernel could load d_conv extra elements along the seqlen dimension, load the conv weights, and perform the convolution in registers or shared memory.
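Why a chunk can run the convolution on its own: each output of a causal depthwise convolution of width d_conv depends only on the d_conv − 1 preceding positions of the same channel. The NumPy sketch below loads that halo per chunk and checks that the result matches the full convolution. Shapes, chunk sizes, and function names are illustrative assumptions, and this is not the Mamba or vLLM conv implementation.

```python
import numpy as np


def causal_depthwise_conv(x, w):
    """x: (seqlen, channels), w: (d_conv, channels); each channel is filtered independently."""
    d_conv = w.shape[0]
    padded = np.concatenate([np.zeros((d_conv - 1, x.shape[1])), x])  # causal zero padding
    # y[t] = sum_k w[k] * x[t - (d_conv - 1) + k]
    return sum(w[k] * padded[k : k + x.shape[0]] for k in range(d_conv))


def chunked_causal_conv(x, w, chunk_size):
    """Same result, computed chunk by chunk with a halo of d_conv - 1 earlier positions."""
    d_conv = w.shape[0]
    outputs = []
    for start in range(0, x.shape[0], chunk_size):
        lo = max(0, start - (d_conv - 1))                  # halo start
        piece = causal_depthwise_conv(x[lo : start + chunk_size], w)
        outputs.append(piece[start - lo :])                # drop outputs that belong to the halo
    return np.concatenate(outputs)


rng = np.random.default_rng(0)
x = rng.standard_normal((50, 6))
w = rng.standard_normal((4, 6))
print(np.allclose(chunked_causal_conv(x, w, 8), causal_depthwise_conv(x, w)))  # True
```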

We have done some experiments with fusing the layernorm, but with limited benefit. There are two methods to fuse this layernorm:

  1. Launch layernorm threadblocks separately. These threadblocks can wait until the corresponding SSD threadblocks have finished and then read the output y from L2 cache instead of VRAM.
  2. Sync SSD threadblocks across heads, exchange norm values, and compute the layernorm in registers or shared memory.

Method 2 was very slow because the SSD threadblocks stalled while syncing and had no other work to do while waiting. Method 1 worked, but reading from L2 instead of VRAM doesn’t provide as much benefit as registers/shared memory. So far, the speedup has been far below the theoretical limit, and it’s unclear whether further optimizations would make it worthwhile given the added complexity.
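The exchange in Method 2 is possible because layernorm statistics decompose: each head can publish a partial sum and sum of squares over its own headdim slice, and any head can combine them into the global mean and variance. The sketch below illustrates this for a single token with a layernorm without affine weights. The function names are hypothetical and the code is not taken from the fused kernel.

```python
import numpy as np


def layernorm_reference(y, eps=1e-5):
    """Layernorm (no affine weights) over all heads x headdim of one token."""
    flat = y.reshape(-1)
    return ((flat - flat.mean()) / np.sqrt(flat.var() + eps)).reshape(y.shape)


def layernorm_from_head_partials(y, eps=1e-5):
    """Each head publishes (sum, sum of squares) for its own slice; after the
    exchange, every head combines them into the global statistics."""
    nheads, headdim = y.shape
    partial_sum = y.sum(axis=1)            # one value per head
    partial_sumsq = (y * y).sum(axis=1)    # one value per head
    n = nheads * headdim
    mean = partial_sum.sum() / n
    var = partial_sumsq.sum() / n - mean**2
    return (y - mean) / np.sqrt(var + eps)


rng = np.random.default_rng(0)
y = rng.standard_normal((8, 64)) * 2.0 + 0.5
print(np.allclose(layernorm_from_head_partials(y), layernorm_reference(y)))  # True
```

The sum-of-squares form can lose precision in low-precision arithmetic. Merging per-head (count, mean, M2) triples with Chan’s parallel variance formula is a more robust alternative.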

Insights on Model Design

With the optimized fusion of the five SSD kernels, Mamba2 prefill is now even cheaper than before. This shifts the runtime-accuracy tradeoff for Mamba2 layers, which could make scaling up both the size and the number of Mamba2 layers the optimal balance in new LLMs. More design insights include:

  • Compute Intensity: The current fused kernel has low compute utilization at the fastest chunk size, so we might be able to afford slightly more complicated operations. Although we could increase compute intensity by increasing the chunk size, that also increases the required registers and other resources, causing an overall slowdown.
  • State Precision: In both the fused and original kernels, the State Passing step must be serial instead of parallel. Although parallel scan algorithms with sublinear latency exist, in practice they can be much slower than the serialized version used in Mamba2. Therefore, keeping the State Passing computation a small fraction of total latency is vital to hiding the serialization latency. If the states can be held in low precision, such as fp16, this significantly helps the fused kernel. Without a fast State Passing step, we might need to split threadblocks further along other dimensions such as headdim, which would slow down the fused kernel overall.
  • VRAM vs L2 tradeoff: Since the fused kernel has higher L2 bandwidth utilization than VRAM bandwidth utilization, sharing less data across threadblocks costs less. If an architecture’s performance benefits greatly from smaller groups, the added VRAM reads could hurt performance less than they did with the original kernels. On the other hand, new GPU features such as TMA multicast loads could reduce L2 bandwidth utilization, speeding up the SSD and reducing this imbalance.
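For reference, the State Passing recurrence h_c = a_c·h_(c−1) + s_c is a composition of affine maps, and affine maps compose associatively. That is what makes sublinear-latency parallel scans possible. The NumPy sketch below compares a sequential loop with a Hillis-Steele scan. The chunk-level decay a_c and local state s_c are toy inputs, and this is not the kernel’s implementation.

```python
import numpy as np


def state_passing_sequential(a, s):
    """h_c = a_c * h_{c-1} + s_c with h_{-1} = 0; one step per chunk."""
    out = np.empty_like(s)
    h = np.zeros_like(s[0])
    for c in range(len(a)):
        h = a[c] * h + s[c]
        out[c] = h
    return out


def state_passing_scan(a, s):
    """Hillis-Steele inclusive scan over the affine maps h -> a_c * h + s_c.
    Applying (a1, b1) then (a2, b2) gives (a1 * a2, a2 * b1 + b2), which is associative,
    so ceil(log2(n)) rounds suffice, each fully parallel across chunks."""
    A, S = a.astype(float), s.astype(float)
    step = 1
    while step < len(A):
        A_prev, S_prev = A.copy(), S.copy()
        # Element c absorbs the prefix that ends at element c - step.
        A[step:] = A_prev[:-step] * A_prev[step:]
        S[step:] = A_prev[step:, None] * S_prev[:-step] + S_prev[step:]
        step *= 2
    return S


rng = np.random.default_rng(0)
a = rng.uniform(0.5, 1.0, 13)            # per-chunk total decay
s = rng.standard_normal((13, 5))         # per-chunk local state (flattened)
print(np.allclose(state_passing_scan(a, s), state_passing_sequential(a, s)))  # True
```

The scan does O(n log n) work and needs a global synchronization between rounds. This is consistent with the observation that, in practice, it can lose to the simple serial loop.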

vLLM Integration

To support variable-length sequences with initial states but without padding, vLLM introduces the idea of “pseudo chunks”. A chunk that contains tokens from multiple sequences has multiple pseudo chunks, one for each sequence in that chunk. Most of the 5 kernels function the same way, with State Passing loading initial states when a new sequence starts. However, Chunk Scan has a larger threadblock grid that iterates over pseudo chunks instead of chunks. To support this in the fused kernel, we use a for loop to process all pseudo chunks in the current chunk. The vLLM Chunk Scan offsets its reads and writes based on where the pseudo chunk starts in the real chunk; we instead use masking based on the sequence index, since masking provides a speedup. Both offsetting and masking read/write the same amount of data at runtime, but masking might be more predictable for the compiler, better aligned, or simply simpler. The vLLM fused kernel is still being integrated, but it shows a similar speedup.
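A sketch of the bookkeeping involved, assuming packed sequences described by cumulative boundaries (a cu_seqlens-style list). It enumerates the pseudo chunks of each chunk and builds the sequence-index mask that replaces offsetting. All function names are hypothetical and are not vLLM’s API.

```python
import numpy as np


def pseudo_chunks(cu_seqlens, chunk_size):
    """List (chunk, seq, start, end) with start/end relative to the chunk."""
    total = cu_seqlens[-1]
    result = []
    for chunk_start in range(0, total, chunk_size):
        chunk_end = min(chunk_start + chunk_size, total)
        for seq in range(len(cu_seqlens) - 1):
            lo = max(cu_seqlens[seq], chunk_start)
            hi = min(cu_seqlens[seq + 1], chunk_end)
            if lo < hi:   # this sequence has tokens in this chunk
                result.append((chunk_start // chunk_size, seq, lo - chunk_start, hi - chunk_start))
    return result


def seq_idx_per_token(cu_seqlens):
    """Sequence index of every token in the packed layout."""
    return np.repeat(np.arange(len(cu_seqlens) - 1), np.diff(cu_seqlens))


def intra_chunk_mask(seq_idx, chunk, chunk_size):
    """Causal mask that also blocks interactions between different sequences."""
    s = seq_idx[chunk * chunk_size : (chunk + 1) * chunk_size]
    causal = np.tril(np.ones((len(s), len(s)), dtype=bool))
    return causal & (s[:, None] == s[None, :])


cu_seqlens = [0, 5, 12, 20]   # three sequences of lengths 5, 7, 8
print(pseudo_chunks(cu_seqlens, 8))
mask = intra_chunk_mask(seq_idx_per_token(cu_seqlens), 0, 8)
```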

Conclusion

In summary, we fused the five Triton kernels of the Mamba-2 SSD prefill into one, yielding a 1.5x-2.5x speedup for the SSD itself, which translates into a ~8–20% end-to-end inference speedup. This significantly boosts throughput for models using Mamba-2 layers. We are excited to integrate these kernel improvements into open-source projects so that the community can easily leverage faster inference with Mamba-2 based models. Stay tuned for updates as this fused SSD kernel lands in the Mamba codebase and in inference frameworks like vLLM.

Appendix A: Optimization Details

Threadblock Order

The State Passing step causes serialization. For a given head, all but one threadblock stall waiting for the previous chunk to be ready. When the GPU runs about 256-1024 threadblocks concurrently but only one makes progress, we get a significant slowdown. Some of the serialization is hidden by the latency of the Chunk State step, since later chunks may still be computing Chunk State rather than stalling in State Passing, but this is not enough.

The nheads and batch dimensions both represent domain parallelism (independent work) in the SSD. Instead of launching all threadblocks for a particular batch and head before moving on to the next, we can launch threadblocks for multiple (batch, head) combinations. If we launch n different (batch, head) combinations for the same chunk before moving on to the next chunk, the serialization drops by a factor of n (instead of only 1 threadblock making progress, n threadblocks make progress). This n must be balanced carefully: if it’s too large, we lose L2 cache locality for passing states, and if it’s too small, threadblocks stall.

As a simple heuristic, we launch threadblocks for all nheads before moving on to the next chunk, but finish all chunks before progressing in the batch dimension. For models with many more or fewer heads, or with significantly different dimensions, a more complicated threadblock order could explicitly combine nheads and batch and then split the result into an inner and an outer dimension, with the inner dimension launched before the next chunk.
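A sketch of the two launch orders, mapping a linear threadblock id to (batch, chunk, head). It also includes a rough proxy for serialization: since State Passing is serial per (batch, head), at most one threadblock per distinct (batch, head) pair among the resident ones can make progress. The sizes and the resident-threadblock count are hypothetical, and the proxy ignores the Chunk State overlap discussed above.

```python
def heads_inner_order(pid, nheads, nchunks):
    """Heads fastest, then chunks, then batch (the order described in the text)."""
    head = pid % nheads
    chunk = (pid // nheads) % nchunks
    batch = pid // (nheads * nchunks)
    return batch, chunk, head


def chunks_inner_order(pid, nheads, nchunks):
    """Naive order: all chunks of one (batch, head) before the next."""
    chunk = pid % nchunks
    head = (pid // nchunks) % nheads
    batch = pid // (nchunks * nheads)
    return batch, chunk, head


def progressing_streams(order, nbatch, nheads, nchunks, resident):
    """Distinct (batch, head) pairs among the first `resident` threadblocks."""
    total = nbatch * nheads * nchunks
    pids = range(min(resident, total))
    return len({(b, h) for b, _, h in (order(p, nheads, nchunks) for p in pids)})


# Hypothetical sizes: 2 sequences, 80 heads, 64 chunks, 512 resident threadblocks.
print(progressing_streams(heads_inner_order, 2, 80, 64, 512))   # 80
print(progressing_streams(chunks_inner_order, 2, 80, 64, 512))  # 8
```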

Cache Hints

The input and output tensors of operations such as the Mamba2 SSD are typically too large to fit in cache. For example, the input and output for 16k context in a Mamba2 SSD with 128 heads of 64 dim each in fp16 will each consume 16k * 128 * 64 * 2B = 256 MiB. Typical GPU L2 caches are 40-50 MiB. Therefore, some data will be evicted from the L2 cache during that kernel. 

Since most of the output tensor does not fit in the L2 cache, it’s not worth using L2 cache capacity for the output to try to speed up the next operation. We can use a cache hint to indicate that the output tensor has the lowest priority for caches. In general, once we access data for the final time in the kernel, we can mark it as low priority for caches. For often reused data, such as CB (which is shared among heads in a group), we can use a high priority cache hint to reduce the chance of eviction.

We can also avoid flushing L1 cache during some sync atomics by specifying “release” semantics. This tells the compiler that previously written data must be globally visible before the atomic operation (e.g. if we are setting a “ready” flag), but this thread does not need to invalidate any caches.

Conditional Separation

In the State Passing step, we have two special cases: reading the initial state instead of the previous chunk’s global state, and writing to the final state instead of the global states tensor. Although conceptually these special cases should only involve swapping the base pointer to read from or write to, the initial- and final-state conditionals increase register pressure and slow down the fused kernel. To solve this, we can handle the special cases outside of the fused SSD kernel. If we replace the nchunks dimension in our state tensor with nchunks + 1, we can copy the initial states into the 0th chunk and copy the final states out of the last chunk. These copies use PyTorch sliced assignment syntax, which results in small kernels with negligible runtime or launch overhead.
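A minimal sketch of the nchunks + 1 layout. NumPy stands in for PyTorch here; `states[:, 0] = initial_states` and `states[:, -1]` are written the same way on torch tensors. The shapes, the toy decay, and the loop that stands in for the fused kernel are assumptions for illustration.

```python
import numpy as np

batch, nchunks, nheads, headdim, dstate = 2, 4, 3, 8, 16
rng = np.random.default_rng(0)
initial_states = rng.standard_normal((batch, nheads, headdim, dstate)).astype(np.float32)

# One extra slot along the chunk dimension: slot 0 holds the state entering chunk 0
# (the initial state); slot c + 1 holds the state after chunk c.
states = np.empty((batch, nchunks + 1, nheads, headdim, dstate), dtype=np.float32)
states[:, 0] = initial_states            # small copy-in before the fused kernel

# Stand-in for the kernel's State Passing step: every chunk reads slot c and writes
# slot c + 1, with no special case for the first or last chunk (toy update rule).
decay = np.float32(0.5)
chunk_states = rng.standard_normal((batch, nchunks, nheads, headdim, dstate)).astype(np.float32)
for c in range(nchunks):
    states[:, c + 1] = decay * states[:, c] + chunk_states[:, c]

final_states = states[:, -1].copy()      # small copy-out after the kernel
```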

Intermediate Datatypes

For some computations, such as applying the A decay to B in Chunk Scan, we can use fp16 instead of fp32. This also replaces upcasting B and downcasting the result with a single downcast of the scale, reducing casting instructions.
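A toy NumPy comparison of the two paths, showing how small the difference is. The fp32 path upcasts B, multiplies, and downcasts the result. The relaxed path downcasts only the scale and multiplies in fp16. The shapes, value ranges, and decay construction are assumptions; the kernel’s exact formula may differ.

```python
import numpy as np

rng = np.random.default_rng(0)
chunk, dstate = 64, 16
# Toy B values in fp16, kept away from zero so relative errors are well defined.
B = (rng.uniform(0.5, 2.0, (chunk, dstate)) * rng.choice([-1.0, 1.0], (chunk, dstate))).astype(np.float16)
# Toy per-position decay exp(cumsum_end - cumsum_t), computed in fp32 (log-decays are negative).
dA_cumsum = np.cumsum(-rng.uniform(0.0, 0.03, chunk)).astype(np.float32)
scale = np.exp(dA_cumsum[-1] - dA_cumsum)          # values in (0, 1]

# fp32 path: upcast B, multiply in fp32, downcast the result.
exact = (scale[:, None] * B.astype(np.float32)).astype(np.float16)
# Relaxed path: downcast only the scale, multiply in fp16.
relaxed = scale.astype(np.float16)[:, None] * B

rel_diff = np.abs(relaxed.astype(np.float32) - exact.astype(np.float32)) / np.abs(exact.astype(np.float32))
print((relaxed == exact).mean(), rel_diff.max())
```

Each path rounds at most a couple of times to fp16, so the two results differ by no more than about one to two fp16 ulps. This is consistent with the small mismatch rates at the 1e-3 tolerance reported above.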

Compile-Time Masks

Triton requires that the dimensions of blocks of tensors in a threadblock are powers of 2 known at compile time. This forces all stores and loads to operate on power-of-2 blocks that might not divide the target tensor exactly. We therefore use masks to cover the entire tensor while avoiding reading or writing out-of-bounds data (or the next block of data). These masks have the same dimensions as the tensor block. However, these masks are not always necessary, because model dimensions like headdim are often divisible by the block size and do not change between different inputs. Triton supports tl.constexpr compile-time parameters, which can be set from other parameters with @triton.heuristics. Therefore, we can automatically enable or disable the headdim dimension of the mask based on whether headdim is divisible by the block size. Although this decision is made at “runtime”, it really only happens once, during the initial JIT compilation of the kernel for this model.
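A plain-Python sketch of the idea. A heuristic decides once whether headdim is a multiple of the block size, and the blocked loop applies a bounds mask only when it is not. In a Triton kernel, the same boolean would typically be declared as a `tl.constexpr` parameter filled in by something like `@triton.heuristics({"EVEN_HEADDIM": lambda args: args["headdim"] % args["BLOCK_HEADDIM"] == 0})`, where EVEN_HEADDIM and BLOCK_HEADDIM are hypothetical names. NumPy stands in for the kernel here.

```python
import numpy as np


def even_headdim(args):
    """Hypothetical heuristic: no headdim mask is needed if headdim is a multiple of the block."""
    return args["headdim"] % args["BLOCK_HEADDIM"] == 0


def blocked_copy(src, block, even):
    """Copy src in power-of-2 blocks; apply a bounds mask only when `even` is False."""
    assert block & (block - 1) == 0, "Triton block dimensions must be powers of 2"
    n = src.shape[0]
    dst = np.zeros_like(src)
    for start in range(0, n, block):
        offs = start + np.arange(block)
        if even:
            dst[offs] = src[offs]               # every block is fully in bounds
        else:
            mask = offs < n                     # drop lanes past the end of the tensor
            dst[offs[mask]] = src[offs[mask]]
    return dst


args = {"headdim": 80, "BLOCK_HEADDIM": 32}     # 80 is not a multiple of 32
src = np.arange(80.0)
print(even_headdim(args), np.array_equal(blocked_copy(src, 32, even_headdim(args)), src))  # False True
```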

Chunk Size

The Mamba2 SSD algorithm takes asymptotically constant computation per token (computation scales linearly with sequence length), but it has a base case of some chunk size that is computed quadratically. Between chunks, the linear algorithm is used, but within a chunk, the quadratic algorithm is used. For more details, see https://tridao.me/blog/2024/mamba2-part1-model/#state-space-duality.
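A minimal NumPy sketch of the chunked algorithm for a single head, checked against the token-by-token recurrence. The code follows the step names used in this post (chunk cumsum, BMM, chunk state, state passing, chunk scan). Several simplifications are assumptions: the decay is a scalar a_t in (0, 1] standing in for the discretized decay, dt and the D residual are omitted, and seqlen must be a multiple of chunk_size.

```python
import numpy as np


def ssd_naive(x, a, B, C):
    """Reference recurrence: h_t = a_t * h_{t-1} + outer(B_t, x_t), y_t = C_t @ h_t."""
    T, P = x.shape
    h = np.zeros((B.shape[1], P))
    y = np.empty((T, P))
    for t in range(T):
        h = a[t] * h + np.outer(B[t], x[t])
        y[t] = C[t] @ h
    return y


def ssd_chunked(x, a, B, C, chunk_size):
    """Quadratic within each chunk, linear (a serial state recurrence) between chunks."""
    T, P = x.shape
    assert T % chunk_size == 0, "sketch assumes seqlen is a multiple of chunk_size"
    y = np.empty((T, P))
    state = np.zeros((B.shape[1], P))                  # state entering the current chunk
    for start in range(0, T, chunk_size):
        sl = slice(start, start + chunk_size)
        xc, Bc, Cc = x[sl], B[sl], C[sl]
        cs = np.cumsum(np.log(a[sl]))                  # chunk cumsum of log-decays
        # Decay from position j to i (j <= i) is exp(cs[i] - cs[j]); the upper
        # triangle is masked out, and clamping avoids overflow there.
        L = np.tril(np.exp(np.minimum(cs[:, None] - cs[None, :], 0.0)))
        y_intra = (L * (Cc @ Bc.T)) @ xc               # BMM + quadratic intra-chunk term
        y_inter = np.exp(cs)[:, None] * (Cc @ state)   # contribution of earlier chunks
        y[sl] = y_intra + y_inter
        chunk_state = (Bc * np.exp(cs[-1] - cs)[:, None]).T @ xc  # this chunk's own state
        state = np.exp(cs[-1]) * state + chunk_state              # state passing (serial)
    return y


rng = np.random.default_rng(0)
T, P, N = 64, 4, 8
x, B, C = rng.standard_normal((T, P)), rng.standard_normal((T, N)), rng.standard_normal((T, N))
a = rng.uniform(0.8, 1.0, T)
for q in (8, 16, 64):
    print(q, np.allclose(ssd_chunked(x, a, B, C, q), ssd_naive(x, a, B, C)))
```

Changing chunk_size shifts work between the quadratic (T·chunk_size) intra-chunk term and the number of states passed between chunks (T/chunk_size). This is the tradeoff behind the optimal chunk size discussed next.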

The optimal chunk size represents a tradeoff between higher computation and resource requirements on one hand, and higher hardware utilization and fewer intermediate states on the other. With the original unfused kernels, the optimal chunk size for Mamba2-2.7B had been 256. However, with the new fused kernel, the optimal chunk size is now 128 for the same model. This smaller chunk size also has the added benefit of reducing register pressure, making the kernel less sensitive to small changes like enabling masks or using higher precision for intermediate results.

Currently, the convention for Mamba2 models is to specify the chunk size in the model’s config. However, since the optimal chunk size varies depending on the original vs fused kernels, it could be better to use a heuristic or autotune the chunk size. This might not be straightforward since the code surrounding the SSD kernels might assume a particular chunk size.

Scale Multiplication Operand

For Chunk State, we can equivalently apply the A decay to X instead of B, since the dimension to be scaled is the inner dimension of the matmul of X and B. Essentially, we compute (X * A[None, :]) @ B instead of X @ (A[:, None] * B). This is faster, probably because the more similar layout causes less register data movement. For example, due to the required Tensor Core data layout, each thread might already hold the A values needed to scale its X values, but to scale B, we might have to load in a different layout and shuffle data back into the required Tensor Core layout.
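The equivalence follows from (X·diag(A))·B = X·(diag(A)·B). A quick NumPy check with assumed toy shapes, where the chunk axis is the contracted dimension:

```python
import numpy as np

rng = np.random.default_rng(0)
# Assumed toy shapes: X is (headdim, chunk), B is (chunk, dstate); the chunk axis is the
# inner (contracted) dimension of X @ B, and A holds one decay factor per chunk position.
X = rng.standard_normal((64, 128))
A = rng.uniform(0.1, 1.0, 128)
B = rng.standard_normal((128, 16))

scale_x = (X * A[None, :]) @ B   # scale the columns of X
scale_b = X @ (A[:, None] * B)   # scale the rows of B
print(np.allclose(scale_x, scale_b))  # True
```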

Appendix B: Summary of Stall Reasons

If we look at the source in NVIDIA Nsight Compute, we can see the warp stalls for each line of code and assembly instruction in the fused kernel on an H100. Assuming that the kernel and block sizes are optimal, warp stalls can reveal potential areas for optimization.

  1. In order to ensure correctness, we use an atomic add to get threadblock ids in increasing order. This accounts for about 3% of the total warp stalls.
  2. Both the Chunk Cumsum and BMM parts of the fused kernel are very fast, so each causes less than 2% of warp stalls.
  3. Atomically checking that the Chunk Cumsum and BMM threadblocks have prepared data for this Chunk State threadblock accounts for about 1.5% of warp stalls.
  4. Chunk State has about 12% of total warp stalls in loading dA, X, and especially B. It also has about 7% stalls in barriers related to scaling and using Tensor Cores.
  5. Despite being serialized along chunks, State Passing has less than 3% stalls on synchronization (including awaiting the previous chunk). Loading the previous states does not cause significant stalling, but updating the state and storing cause about 6% stalls awaiting shared memory or a barrier.
  6. For the previous state’s contribution in Chunk Scan, loading C is about 5% loading stalls, prev_states is about 3% barrier stalls, and the computation is about 8% barrier, loading (for scale), and instruction dependency stalls.
  7. The current chunk’s contribution in Chunk Scan has about 13% stalls in loading data and 18% stalls in computation (including scaling).
  8. The residual (scaled by D) accounts for about 6% of total stalls for loading, shared memory, and computation.

Overall, these stalls are for legitimate reasons and are not easy to optimize away.

]]>
Some Matrix Multiplication Engines Are Not As Accurate As We Thought
https://pytorch.org/blog/some-matrix-multiplication-engines-are-not-as-accurate-as-we-thought/
Fri, 06 Feb 2026 22:15:55 +0000

What is an accumulator in an accelerator’s GEMM engine and why does it matter?

GPUs and custom accelerators include specialized compute engines for matrix multiplication (also known as matmul or GEMM), such as NVIDIA’s Tensor Cores. These engines efficiently perform matmul on small tensor blocks; therefore, compilers or libraries typically divide large matmul problems into many smaller ones and feed them to these engines. Usually, the output of an FP8 (e4m3) Tensor Core matmul between blocks of shape (block_size_m, block_size_k) and (block_size_k, block_size_n) is a (block_size_m, block_size_n) tensor in FP32 (e8m23). However, one interesting detail that users rarely notice is that, for hardware efficiency reasons, this FP32 output could have fewer than 23 effective mantissa bits. In other words, the precision of this Tensor Core operation is lower than the FP32 output format suggests. This hardware design choice has been reported to impact model accuracy under certain circumstances [1, 2]. Therefore, from a GPU user’s perspective, we would like to verify the hardware design in use: even though existing hardware cannot be changed, custom kernels can still be written to preserve the highest achievable accuracy when needed. For hardware designers, it is equally important to have a convenient and efficient way to quantify the impact of this design choice.

Before we dive into details, we need to understand the role of an “accumulator” and the reason for employing reduced precision. Let’s first consider a hypothetical compute engine that can handle an FP8 matmul of block sizes (3, 4) and (4, 3), as illustrated in Fig. 1a. Zooming into the compute engine, the most basic operation would be a row-column inner product, i.e., cᵢⱼ = ∑ₖ aᵢₖ * bₖⱼ.

One can imagine that an efficient hardware design would simply implement 4 multipliers to compute each product aᵢₖ * bₖⱼ, followed by 3 adders to sum up the intermediate results, as shown in Fig. 1b. In this simple example, the multiplication part can be done in one single parallelized “compute step”, assuming enough multipliers are available. But the addition part requires 2 compute steps to complete, because it has to be done as a hierarchical (tree) reduction in which each level depends on the previous one. If we scale up this unit design to N elements, multiplication will still take only one step, while addition will take ⌈log₂ N⌉ steps.
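The step count can be checked with a small model of the unit described above: one parallel multiply step followed by a pairwise adder tree. This is a toy counting model, not a description of any specific hardware.

```python
import math


def inner_product_tree(row, col):
    """One parallel multiply step, then a pairwise adder tree.
    Returns (result, number of sequential adder stages)."""
    level = [a * b for a, b in zip(row, col)]   # all multipliers fire at once
    stages = 0
    while len(level) > 1:
        # Additions within one stage are independent; the next stage waits for them.
        nxt = [level[i] + level[i + 1] for i in range(0, len(level) - 1, 2)]
        if len(level) % 2:
            nxt.append(level[-1])               # odd leftover passes through
        level = nxt
        stages += 1
    return level[0], stages


print(inner_product_tree([1, 2, 3, 4], [5, 6, 7, 8]))  # (70, 2)
for n in (4, 16, 64):
    print(n, inner_product_tree([1] * n, [1] * n)[1], math.ceil(math.log2(n)))
```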

Furthermore, each multiplier only needs to compute FP8 * FP8 (e4m3), which involves a 4-bit + 4-bit addition (for the exponent) and a 4-bit x 4-bit multiplication (for the mantissa). However, since each partial product needs to be aligned correctly, the subsequent adders must use significantly more bits than the multipliers. As illustrated by Fig. 2 (just an example, not a real FP8 case), adding two limited-precision FP numbers with only 4 mantissa bits could produce an FP number that requires many more mantissa bits. This loosely explains why the circuit complexity and cost (silicon area and power) of a floating point multiply-accumulate (MAC) operation depend strongly on the accumulation precision. Therefore, even though it is safer to use FP32 as the accumulation precision (Fig. 2b), it is worthwhile to explore opportunities to use reduced accumulation precision.
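A small exact-arithmetic illustration of the alignment effect, using Python’s fractions module. Two toy values with 4-bit significands, written as 1.xxx in binary with the implicit leading 1 as in e4m3, have exponents 10 apart. Their exact sum needs 14 significand bits, roughly the exponent gap plus the operand width.

```python
from fractions import Fraction


def significand_bits(x):
    """Number of bits needed to store x exactly as (odd integer) * 2**k."""
    x = Fraction(x)
    if x == 0:
        return 0
    num = abs(x.numerator)
    while num % 2 == 0:     # strip trailing zero bits (only possible when x is an integer)
        num //= 2
    return num.bit_length()


a = Fraction(0b1111) * 2**4        # 1.111b * 2**7  = 240
b = Fraction(0b1001) / 2**6        # 1.001b * 2**-3 = 9/64

print(significand_bits(a), significand_bits(b), significand_bits(a + b))  # 4 4 14
```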

With these examples in mind, the benefits of using reduced‑precision adders in matmul engines become clear.

How to Verify Accumulator Precision? (Using TensorCore as an Example)

Given that a matmul accumulator could be designed with fewer than 23 mantissa bits, the actual output is effectively e8mNacc (where Nacc < 23), padded with trailing 0s up to e8m23. In other words, the output of an FP8 TensorCore may look like FP32, but the mantissa bits below the top Nacc were never calculated during the computation. In this blog, we demonstrate a simple approach to investigate the accumulator precision using a Triton kernel.

Assuming the TensorCore output has only Nacc effective mantissa bits (as in e8mNacc), i.e., the last 23 − Nacc bits are 0 already, if we apply a mask to truncate the last Ntrun bits of the TensorCore output, as long as Ntrun ≤ 23 − Nacc, the final matmul results should remain unchanged. Furthermore, by sweeping Ntrun and comparing the matmul output to a reference (i.e., Ntrun = 0), we can infer the accumulator precision of the FP matmul unit under investigation. Here, “truncation of Ntrun bits” refers to zeroing out the last Ntrun bits of a floating point number, which are the least-significant bits (LSBs) of the mantissa.
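A NumPy sketch of the sweep logic, applied to a synthetic tensor whose 10 lowest mantissa bits are already zero (a stand-in for an accumulator with Nacc = 13). The real experiment applies the mask inside the Triton kernel and compares final matmul outputs against the Ntrun = 0 reference. Here the sweep is applied directly to the stand-in tensor, and the helper names are hypothetical.

```python
import numpy as np


def truncate_lsbs(x, n_trun):
    """Zero the n_trun least-significant mantissa bits of float32 values (no rounding)."""
    mask = np.uint32((0xFFFFFFFF << n_trun) & 0xFFFFFFFF)
    return (np.asarray(x, dtype=np.float32).view(np.uint32) & mask).view(np.float32)


def infer_effective_mantissa_bits(output):
    """Sweep n_trun upward; the largest value that leaves every element unchanged is 23 - Nacc."""
    unchanged = 0
    for n_trun in range(1, 24):
        if not np.array_equal(truncate_lsbs(output, n_trun), output):
            break
        unchanged = n_trun
    return 23 - unchanged


rng = np.random.default_rng(0)
exact = rng.standard_normal(4096).astype(np.float32)
# Stand-in for hardware output: pretend the accumulator keeps only 13 mantissa bits.
hw_output = truncate_lsbs(exact, 23 - 13)
print(infer_effective_mantissa_bits(hw_output))  # 13
```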

Why Triton?

We use the Triton language because it allows the proposed method to generalize to other accelerators that support Triton. It also greatly speeds up development for this experiment due to its simplicity and the right level of accelerator control. Although Triton is expected to evolve over time, our implementation is based on Triton’s matmul tutorial, so we anticipate that few parts of the code will need future rewrites.

Experiments

Runnable code is provided at the end of this notebook. Here, we adopted a Triton matmul kernel from the Triton tutorial and added a simple truncation function. Since plenty of detail can be found in the original tutorial, we only highlight the truncation-related modifications we made. Roughly speaking, matmul(A, B) is decomposed into smaller blocks that are processed in parallel. Each block of A and B has shape (BLOCK_SIZE_M, BLOCK_SIZE_K) and (BLOCK_SIZE_K, BLOCK_SIZE_N), respectively. The block-level matmul is computed by Triton’s tl.dot() function, producing a temporary tensor accumulator_inner of shape (BLOCK_SIZE_M, BLOCK_SIZE_N), which is assumed to have only Nacc effective mantissa bits.

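  1. Truncation of accumulator_inner: We truncate the last Ntrun bits of accumulator_inner using a bit operation with a pre-defined mask. For simplicity, we ignore rounding by setting round_bit = 0.
import triton
import triton.language as tl
from triton.language.extra import libdevice

@triton.jit
def prep_round_and_trun_mask(trun_bits: tl.constexpr):
    round_bit = 1 << (trun_bits - 1) if trun_bits > 0 else 0  # half of the lowest kept bit; pass 0 for pure truncation
    trun_mask = ~tl.cast((1 << trun_bits) - 1, tl.uint32)  # zeroes the trun_bits LSBs of an FP32 word
    return round_bit, trun_mask

@triton.jit
def round_and_trun(x, round_bit, trun_mask):
    """Round and truncate (usually for accumulator)."""
    # Reinterpret FP32 as uint32, add the rounding constant, mask the LSBs, reinterpret back.
    return libdevice.uint_as_float(
        (libdevice.float_as_uint(x) + round_bit) & trun_mask
    )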

  2. Accumulation across the K-dimension: Each truncated accumulator_inner is further accumulated into a pre-allocated FP32 tensor, accumulator, while stepping through the K-dimension. accumulator has the same shape as accumulator_inner.

  3. Writing the results back: After iterating through the K-dimension, the final accumulator values are written back to the corresponding block of the target output tensor C, whose shape is (M, N).

Results and discussions

From both Table 1 and Fig. 3 below, we observed that truncating up to the 10 least significant mantissa bits of the output (on the H100 FP8 TensorCore) produces exactly the same results as the case with no truncation. This indicates that those bits were already 0 in the original output. The experiment therefore suggests that the accumulator is implemented using a special FP22 format (e8m13: 1 sign, 8 exponent, and 13 mantissa bits) for compute efficiency reasons. We repeated this experiment on an RTX 4000-series GPU (Ada Lovelace architecture) and observed the same behavior.

One important consideration is that this experiment relies on the Triton compiler to translate Triton code into GPU machine code. We must ensure that the TensorCore performing the task is indeed the one we intended to inspect, i.e., FP8. In rare situations, the Triton compiler may choose to use FP16 TensorCore instructions for certain FP8 computations. The most reliable way to confirm the actual hardware instructions executed is to use NVIDIA’s Nsight Compute profiler, ncu [3] (included in the CUDA Toolkit), to inspect the underlying machine (SASS) instructions associated with the Triton tl.dot call.

Readers can save this notebook as a Python file and then launch ncu using the following command line:

/usr/local/cuda-13.0/bin/ncu --target-processes all --set full \
    --import-source yes -f --kernel-name matmul_kernel --launch-skip 3 \
    --launch-count 1 -o ./tl_fp8mm_backend_H100 \
    python accumulator_precision_test.py

From the ncu profiler readout shown below, we found that FP8xFP8 tl.dot() for the chosen block size (MxNxK=64x64x32) was translated into a QGMMA instruction, an FP8-TensorCore-specific instruction. This confirms that the FP8 TensorCore was indeed used.

As mentioned earlier, the Triton compiler can sometimes choose a different implementation for tl.dot. For example, if we set num_warps = 2 in the kernel_config dictionary and repeat the experiment, Triton will pack FP8 into FP16 and use HMMA, an FP16-TensorCore-specific instruction, to perform the computation. In this case, the results show that the accumulator of the FP16 TensorCore is only 1 bit shorter than FP32 (i.e., 22 effective mantissa bits).

Furthermore, since a specialized matmul unit is designed to handle inputs of certain fixed sizes, if the BLOCK_SIZE we choose exceeds what the TensorCore can handle, the compiler or CUDA library will automatically decompose the operation into several smaller operations. In our Triton code, we can increase BLOCK_SIZE_K to 128 and verify with ncu again. We will see that each WGMMA instruction can only handle K=32, which means an additional summation is needed to combine the partial results from multiple TensorCore calls. A natural question is: what precision is used for this intermediate summation? This is the same FP alignment and precision loss problem we have been discussing. Based on the output of the K=128 experiment, we still observe 13 effective mantissa bits.

This provides an important insight: if we choose block sizes for the Triton kernel that exceed the TensorCore’s base design, whether for performance reasons or due to autotuning, there can be additional precision loss from reduced-precision summation. Therefore, if matmul precision is a critical concern (especially when training and backpropagation are involved), before falling back to FP16, we should first try an intermediate FP32 accumulation, as we did in the Triton code. We demonstrated the effect of BLOCK_SIZE_K on accuracy here, but readers should keep in mind that smaller blocks will impact kernel performance. Readers may want to start from a larger block size (e.g., 256 or 512 if autotuning suggests it), then gradually reduce it to 128, as used in [1], and weigh the trade-off between using FP16 and decreasing the block size. Alternatively, if the custom kernel uses cuBLAS, keeping FP8 fast accumulation disabled (CUBLASLT_MATMUL_DESC_FAST_ACCUM left at 0) achieves the same effect of promoting the accumulation precision; with fast accumulation enabled, intermediate results are not periodically promoted to higher precision [4].
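A NumPy emulation of why periodic FP32 promotion helps. The reduced-precision accumulator is modeled by truncating the running sum to 13 mantissa bits after every FP32 addition. This is our assumption about its behavior; real hardware may round differently. Promotion flushes the partial sum into an FP32 accumulator every 128 steps. The sizes and uniform inputs are toy choices, and the function names are hypothetical.

```python
import numpy as np


def truncate_lsbs(x, n_trun):
    """Zero the n_trun least-significant mantissa bits of float32 values."""
    mask = np.uint32((0xFFFFFFFF << n_trun) & 0xFFFFFFFF)
    return (np.asarray(x, dtype=np.float32).view(np.uint32) & mask).view(np.float32)


def emulated_matmul(A, B, acc_mantissa_bits=13, promote_every=None):
    """Accumulate A @ B one k at a time in an emulated reduced-precision accumulator;
    optionally flush the partial sum into an FP32 accumulator every `promote_every` steps."""
    n_trun = 23 - acc_mantissa_bits
    M, K = A.shape
    acc_low = np.zeros((M, B.shape[1]), dtype=np.float32)
    acc_fp32 = np.zeros_like(acc_low)
    for k in range(K):
        acc_low = truncate_lsbs(acc_low + np.outer(A[:, k], B[k]), n_trun)
        if promote_every and (k + 1) % promote_every == 0:
            acc_fp32 += acc_low                 # promotion to FP32
            acc_low = np.zeros_like(acc_low)
    return acc_fp32 + acc_low


rng = np.random.default_rng(0)
M, K, N = 8, 4096, 8
A = rng.uniform(0.0, 1.0, (M, K)).astype(np.float32)
B = rng.uniform(0.0, 1.0, (K, N)).astype(np.float32)
ref = A.astype(np.float64) @ B.astype(np.float64)


def rel_err(out):
    return float(np.mean(np.abs(out - ref) / ref))


err_low = rel_err(emulated_matmul(A, B, 13))                          # 13-bit accumulator over all of K
err_promoted = rel_err(emulated_matmul(A, B, 13, promote_every=128))  # promote every 128 steps
err_fp32 = rel_err(emulated_matmul(A, B, 23))                         # plain FP32 accumulation
print(err_low, err_promoted, err_fp32)
```

Promotion keeps the reduced-precision partial sums small, so each truncation discards less in absolute terms. With truncation, the error is also biased in one direction, so without promotion it keeps growing with K.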

Finally, the concept of a reduced-precision accumulator also applies to INT8xINT8 engines. The main difference between FP8 and INT8 matmul is that INT8 accumulator truncation occurs on the most significant bits (MSBs) rather than the least significant bits (LSBs). In other words, for INT8 we need to consider overflow, rather than the loss of low-order bits seen with FP8. Simple modifications to the provided Triton kernel can be made to investigate INT8 behavior. We leave this exercise to interested readers.
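As a starting point, the sketch below computes the worst-case signed accumulator width for a sum of k INT8 x INT8 products. It then emulates what a narrower two’s-complement accumulator does when the MSBs are lost. The accumulator widths are hypothetical and do not describe any specific GPU.

```python
import numpy as np


def int8_acc_bits_needed(k):
    """Smallest signed (two's-complement) accumulator width that can hold any sum of
    k INT8 x INT8 products without overflow."""
    max_pos = k * (-128) * (-128)   # largest product, repeated k times
    min_neg = k * (-128) * 127      # most negative product, repeated k times
    bits = 2
    while not (-(1 << (bits - 1)) <= min_neg and max_pos <= (1 << (bits - 1)) - 1):
        bits += 1
    return bits


def wrap_to_bits(value, bits):
    """Emulate a `bits`-wide two's-complement accumulator: the MSBs are simply lost."""
    value = int(value) & ((1 << bits) - 1)
    return value - (1 << bits) if value >= (1 << (bits - 1)) else value


k = 32
a = np.full(k, -128, dtype=np.int8)
b = np.full(k, -128, dtype=np.int8)
exact = int(a.astype(np.int64) @ b.astype(np.int64))    # widen first: int8 @ int8 would itself overflow
print(int8_acc_bits_needed(k), exact)                    # 21 524288
print(wrap_to_bits(exact, 20), wrap_to_bits(exact, 21))  # -524288 524288
```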

Conclusion

We explained why the precision of a matmul engine's accumulator matters and demonstrated a simple method to verify the design of an existing accelerator. Understanding accumulator precision is crucial for users with accuracy-sensitive applications who write custom kernels, as well as for hardware designers who need to emulate this behavior in their next-generation designs. More importantly, this Triton-kernel-based approach can be seamlessly combined with the PyTorch ecosystem, which means the same technique can be extended to other existing and future accelerators that support the Triton language, significantly reducing development time.

Reference

  1. DeepSeek-V3 Technical Report, Section 3.3.2, Increasing Accumulation Precision. https://arxiv.org/html/2412.19437v1
  2. SageAttention2, Introduction, Challenge C2. https://arxiv.org/html/2411.10958v7
  3. NVIDIA Nsight Compute (ncu) documentation. https://docs.nvidia.com/nsight-compute/index.html
  4. NVIDIA cuBLAS documentation. https://docs.nvidia.com/cuda/cublas/

Runnable code can be found here

https://gist.github.com/chichun-charlie-liu/88a99949fcbe589aa5f71e48616ac944

Building Highly Efficient Inference System for Recommenders Using PyTorch
https://pytorch.org/blog/building-highly-efficient-inference-system-for-recommenders-using-pytorch/
Thu, 05 Feb 2026 18:00:59 +0000

Why Choose PyTorch for Recommendation Systems

PyTorch has emerged as the de facto framework in the AI community, with the majority of cutting-edge research, especially in areas like recommendation systems, retrieval, and ranking, being conducted with PyTorch. Developers are eager to bring the latest model advancements into production as quickly as possible. A PyTorch-based recommendation inference system is well-suited to this need, enabling both (1) high efficiency and (2) rapid model adoption in production environments.

In this blog, we will discuss the design of a high-performance recommendation inference system built with PyTorch. Approaches based on these design principles have been thoroughly validated and have successfully served extremely high volumes of traffic, demonstrating strong efficiency and reliability. Our PyTorch-based recommendation inference system serves as the backbone for Meta’s most critical machine learning workloads. Powering global surfaces, including Feed, Ads, Instagram, Reels, Stories, and Marketplace, the system manages a diverse array of ML architectures. These range from sophisticated extensions of the foundational Deep Learning Recommendation Model (DLRM) to cutting-edge, novel modeling techniques such as DHEN (Deep Hierarchical Ensemble Network), HSTU (Hierarchical Sequential Transducer Unit), Wukong, and more.

A Typical Recommendation Research to Production Inference Workflow

The Overall Workflow

After training, a model definition and its trained weights are delivered for inference, establishing a clear contract between the training and inference stages. However, running a training model directly in a production inference environment is highly inefficient and does not meet the performance requirements of real-world applications.

To address this, we need to rapidly and reliably ship trained models to production, while also supporting frequent updates as models are improved or retrained. This dynamic environment—with many models and many versions—demands a robust transformation pipeline that converts trained models into optimized inference models. Such a pipeline ensures that the resulting inference model files are tailored for efficient hardware utilization, enabling high throughput (QPS, i.e., queries per second) and meeting strict latency requirements. In summary, a dedicated system for transforming training models into production-ready inference models is essential for maintaining agility, scalability, and performance in our model deployment process.

Trained Model to Production Inference Transformation Flow

Defining the Inference Model and Weights Mapping

The trained model often includes components that are only necessary during training, such as loss functions and certain regularization techniques. It is best practice to define a dedicated inference model that mirrors the forward logic of the training model, while also allowing for inference-only optimizations. Additionally, a mapping between the inference model’s parameters and the trained model weights (checkpoint) must be established, especially if fully qualified parameter names differ between training and inference. This mapping should be maintained and updated throughout the inference model preparation process.
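A minimal sketch of such a weights mapping, assuming the name changes can be expressed as ordered prefix rewrites and that training-only components live under known prefixes. The prefixes used here (`model.dense.`, `loss_fn.`, and so on) are hypothetical. The function only relies on the checkpoint being a `Mapping[str, Any]`, so a PyTorch `state_dict` works as input.

```python
from typing import Any, Dict, Iterable, Mapping, Sequence, Tuple


def remap_checkpoint(
    train_state: Mapping[str, Any],
    prefix_map: Sequence[Tuple[str, str]],
    drop_prefixes: Iterable[str] = (),
) -> Dict[str, Any]:
    """Translate training-checkpoint keys into inference-model keys.

    prefix_map: ordered (training_prefix, inference_prefix) pairs; first match wins.
    drop_prefixes: keys of training-only components (e.g. loss modules) to skip.
    """
    drop_prefixes = tuple(drop_prefixes)
    out: Dict[str, Any] = {}
    for name, value in train_state.items():
        if name.startswith(drop_prefixes):  # str.startswith accepts a tuple
            continue
        new_name = name
        for old, new in prefix_map:
            if name.startswith(old):
                new_name = new + name[len(old):]
                break
        if new_name in out:
            raise ValueError(f"two checkpoint entries map to {new_name!r}")
        out[new_name] = value
    return out
```

Loading the result with `inference_model.load_state_dict(remapped, strict=True)` would then surface any parameter the mapping missed, because strict loading raises on missing or unexpected keys.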

Capturing the Computation Graph from Python Models

To enable efficient inference, a series of model transformations must be applied to the inference model. Applying these optimizations requires converting PyTorch models defined in Python into a graph representation. Capturing a PyTorch model's computation graph is a challenging task. Using torch.fx to extract an FX graph is a common practice. This method assumes that the model architecture does not contain cyclic structures. Submodules with complex control flow can be marked as leaf modules, so that the tracer records each of them as a single call instead of tracing through it, which simplifies graph extraction.
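As an illustration of leaf-module marking, the sketch below subclasses `torch.fx.Tracer` and overrides `is_leaf_module` so that a submodule with data-dependent control flow is kept opaque. `DataDependentGate` and `TinyModel` are hypothetical stand-ins for real model components.

```python
import torch
import torch.fx as fx
from torch import nn


class DataDependentGate(nn.Module):
    """Hypothetical submodule whose Python `if` depends on tensor values."""

    def forward(self, x):
        if x.sum() > 0:
            return x * 2
        return x


class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        self.gate = DataDependentGate()

    def forward(self, x):
        return self.gate(torch.relu(self.linear(x)))


class LeafTracer(fx.Tracer):
    """Treats instances of `leaf_types` as opaque call_module nodes."""

    def __init__(self, leaf_types):
        super().__init__()
        self.leaf_types = tuple(leaf_types)

    def is_leaf_module(self, m, module_qualified_name):
        return isinstance(m, self.leaf_types) or super().is_leaf_module(m, module_qualified_name)


def capture(model: nn.Module, leaf_types=()) -> fx.GraphModule:
    tracer = LeafTracer(leaf_types)
    graph = tracer.trace(model)
    return fx.GraphModule(model, graph)


if __name__ == "__main__":
    gm = capture(TinyModel(), [DataDependentGate])
    print(gm.graph)  # `gate` appears as a single call_module node
```

Calling `torch.fx.symbolic_trace(TinyModel())` without the leaf marking would fail at the `if x.sum() > 0` branch, because a symbolic proxy cannot be converted to a Python bool.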

Recently, torch.export has become a more mature tool for capturing computation graphs, offering improved support for models with control flow. However, the resulting PT2IR (a specialized FX graph) can be quite low-level and decomposed, which may complicate certain model transformations.

Model Transformation and Optimization

After capturing the FX graph, a variety of optimizations can be applied through model transformation passes. Below are some common transformation passes:

  • Model Splitting: For distributed inference scenarios, it is often necessary to split the full “forward” graph into smaller subgraphs. Each subgraph represents the forward pass of a submodule, enabling distributed execution across multiple devices or hosts. Additionally, these transformations can group similar computations together, further enhancing overall efficiency.
  • Operator Fusion: Multiple operations can be replaced with a single, fused implementation to improve efficiency. This can be achieved by swapping submodules or applying graph-level transformations.
  • Quantization: Similar to operator fusion, certain layers (e.g., linear layers) can be replaced with quantized versions to reduce memory usage and improve inference speed. TorchAO provides linear-layer quantization that works with PT2.
  • Compilation (a.k.a. Lowering): Model compilation techniques are typically applied ahead of time as part of the transformation process. This step converts model code into lower-level representations that are better suited for the target inference devices. (See the AI Compiler section below for more details.)
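Fusion and quantization by submodule swapping share the same mechanics: walk the module tree and replace matching children. The sketch below shows that pattern with a hypothetical, hand-written weight-only int8 linear layer (per-output-channel symmetric scales). It is meant only to show the swap mechanics; it is not TorchAO's implementation, which should be preferred in practice.

```python
import torch
import torch.nn.functional as F
from torch import nn


class Int8WeightOnlyLinear(nn.Module):
    """Hypothetical illustration: int8 weights, dequantized on the fly."""

    def __init__(self, linear: nn.Linear):
        super().__init__()
        w = linear.weight.detach()
        # Symmetric per-output-channel scale so the max |w| of each row maps to 127.
        scale = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 127.0
        self.register_buffer("qweight", torch.round(w / scale).clamp(-127, 127).to(torch.int8))
        self.register_buffer("scale", scale)
        bias = None if linear.bias is None else linear.bias.detach().clone()
        self.register_buffer("bias", bias)

    def forward(self, x):
        weight = self.qweight.to(x.dtype) * self.scale  # dequantize
        return F.linear(x, weight, self.bias)


def swap_modules(module: nn.Module, should_swap, make_replacement) -> nn.Module:
    """Recursively replace every child for which `should_swap(child)` is True."""
    for name, child in list(module.named_children()):
        if should_swap(child):
            setattr(module, name, make_replacement(child))
        else:
            swap_modules(child, should_swap, make_replacement)
    return module


if __name__ == "__main__":
    model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
    swap_modules(model, lambda m: isinstance(m, nn.Linear), Int8WeightOnlyLinear)
    print(model)
```

The same `swap_modules` helper can install a fused module in place of a subtree; only the predicate and the replacement factory change.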

Graph Transformation Example: Full Forward Graph to Split Graph
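One existing utility for this kind of split is `torch.fx.passes.split_module.split_module`, which assigns each graph node to a partition via a callback and turns each partition into its own submodule. In the sketch below, the two-stage model and the partition rule (nodes whose target starts with `sparse` go to partition 0, everything else to partition 1) are hypothetical. Production pipelines typically use their own, richer partitioning logic.

```python
import torch
import torch.fx as fx
from torch import nn
from torch.fx.passes.split_module import split_module


class TwoStageModel(nn.Module):
    """Hypothetical model with a 'sparse' front end and a 'dense' back end."""

    def __init__(self):
        super().__init__()
        self.sparse_proj = nn.Linear(8, 4)
        self.dense = nn.Linear(4, 2)

    def forward(self, x):
        return self.dense(torch.relu(self.sparse_proj(x)))


def partition_of(node: fx.Node) -> int:
    # call_module targets are strings; call_function targets are callables.
    if isinstance(node.target, str) and node.target.startswith("sparse"):
        return 0
    return 1


if __name__ == "__main__":
    model = TwoStageModel()
    traced = fx.symbolic_trace(model)
    split = split_module(traced, model, partition_of)
    print(split.graph)  # top-level graph now calls submod_0 then submod_1
```

Each resulting `submod_*` can then be placed on a different device or host, with the top-level graph acting as the DAG that stitches them together.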

Model Serialization

Standard PyTorch models use the pickle format for storage, but this approach is insufficient for production due to weak backward compatibility and Python dependency issues. To address these challenges, several serialization solutions are available:

| Solution | Description | Pros | Cons |
|---|---|---|---|
| TorchScript | Capture TorchScript IR through scripting or tracing, and save in the TorchScript format. | 1) Mature, with strong backward-compatibility support; 2) Solid control-flow support | 1) Some constraints on model definition (e.g., no complex data structures); 2) Deprecated and no longer supported |
| torch.export | Export the PyTorch model as PT2IR. | 1) The official way to serialize models in PT2; 2) Under active development | 1) Control flow may need additional handling |
| torch.package | Directly export related Python modules as source code and pickled objects. | 1) Great flexibility | 1) May require manual effort to define module boundaries; 2) Requires a Python dependency |

Regardless of the serialization format, the resulting artifact should be a zip file. This allows for easy inspection and debugging by unzipping the contents. Processed weights can also be packaged within the zip file. We are prioritizing torch.export for new model development over older tools like TorchScript and torch.package. With TorchScript now deprecated, torch.export provides a more robust path forward with active feature development, and, unlike torch.package, it allows for a Python-independent runtime, which delivers the performance we require.

Model Loading and Execution

Once the inference models are prepared, you will have a set of inference model files. For extremely large models, it may be necessary to load the model structure and weights separately, which could require custom logic for saving and loading.

After loading the model files, the runtime begins processing inference requests. Since PyTorch does not natively provide serving capabilities beyond model execution, an additional server layer is required to manage inference serving. Below, we outline the key features of an efficient and scalable PyTorch inference server for recommendation systems:

Lightweight PyTorch Executor Wrapper

  • The server converts requests to PyTorch model inputs. This wrapper should be minimal to ensure efficiency.

Efficient and Flexible API

  • In a distributed environment, different components of the model communicate via APIs, which necessitates precise semantic definitions, such as specifying the batch dimension and other relevant parameters.
  • Tensor-based APIs align well with the PyTorch model’s forward method.
  • Zero-copy (in-place) APIs allow us to update models in-place, efficiently and seamlessly transitioning from serving one version of a model to the next without requiring significant additional capacity to load both model versions during the transition.

DAG Representation and Executor

  • Modules with similar characteristics (e.g., all embedding bags) can be grouped into dedicated submodules for batch execution.
  • After model splitting, the original forward function is represented as a Directed Acyclic Graph (DAG), with each node corresponding to a submodule. An executor is required to manage the execution of this DAG.
  • DAG nodes may be deployed across multiple hosts, which necessitates support for remote execution. In such cases, an efficient communication library is essential to ensure seamless and performant interactions between distributed components.
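A minimal, in-process sketch of such a DAG executor, using the standard library's `graphlib.TopologicalSorter` to release a node as soon as its dependencies finish and a thread pool to run independent nodes concurrently. The node callables are hypothetical placeholders; in a distributed deployment they would wrap local submodule calls or RPCs to remote hosts.

```python
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait
from graphlib import TopologicalSorter
from typing import Any, Callable, Dict, Sequence, Tuple

# A node is (callable, names of the nodes or graph inputs it consumes).
DagNode = Tuple[Callable[..., Any], Sequence[str]]


def run_dag(nodes: Dict[str, DagNode], inputs: Dict[str, Any], output: str,
            max_workers: int = 4) -> Any:
    """Execute a DAG of submodule calls, dispatching every node whose
    dependencies are satisfied concurrently."""
    sorter = TopologicalSorter({name: set(deps) for name, (_, deps) in nodes.items()})
    sorter.prepare()  # raises graphlib.CycleError if the graph is not a DAG
    values: Dict[str, Any] = dict(inputs)
    pending = {}
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        while sorter.is_active():
            for name in sorter.get_ready():
                if name in values:  # graph inputs are already available
                    sorter.done(name)
                    continue
                fn, deps = nodes[name]
                future = pool.submit(fn, *(values[d] for d in deps))
                pending[future] = name
            if not pending:
                continue
            finished, _ = wait(pending, return_when=FIRST_COMPLETED)
            for future in finished:
                name = pending.pop(future)
                values[name] = future.result()  # re-raises node failures
                sorter.done(name)
    return values[output]


if __name__ == "__main__":
    nodes = {
        "emb": (lambda x: x * 2, ["x"]),
        "dense_a": (lambda e: e + 1, ["emb"]),  # dense_a and dense_b can
        "dense_b": (lambda e: e * 10, ["emb"]),  # run at the same time
        "merge": (lambda a, b: a + b, ["dense_a", "dense_b"]),
    }
    print(run_dag(nodes, {"x": 3}, "merge"))  # 67
```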

Optimizations

In the previous section, we outlined the core principles for building a robust, efficient, and scalable recommendation inference system with PyTorch, one that can handle high traffic volumes and meet stringent production requirements. To further enhance system performance, we now discuss several key optimization strategies.

GPU (Accelerator) Inference

With the emergence of new model architectures, computational demands have increased significantly. CPUs often struggle to meet the latency requirements for running such models online, making accelerators like GPUs a natural choice. However, running the entire model on a single GPU can be inefficient, and models may not fit within the memory constraints of a single device. Therefore, splitting models into multiple segments and executing the most compute-intensive layers on GPUs is a practical approach.

Additionally, GPU kernel launch overhead can be substantial. To mitigate this, batching requests together reduces the number of kernel launches and improves overall throughput.

C++ Runtime

While the most straightforward way to run PyTorch models is via Python, the Python runtime introduces noticeable overhead, especially as QPS (queries per second) increases. Typically, Python overhead becomes significant at QPS ≥ 100, and can become a severe bottleneck at QPS ≥ 1000.

For high-QPS scenarios (≥ 100 per host), we recommend using a C++ (or Rust) runtime. Both TorchScript (for TorchScript models) and ExecuTorch (for models saved with torch.export) provide C++ runtimes. Recently, development has focused on a new runtime, torch.nativert, designed for executing torch.export models across servers, as an alternative to the TorchScript runtime, which has been deprecated as of the last PyTorch Conference.

Distributed Inference (DI)

Running the entire inference model as a monolith can be inefficient or even infeasible. Instead, splitting the model into multiple components and distributing them across different workers can both improve efficiency and enable scaling to larger model sizes. Common DI patterns include:

  • CPU-GPU DI: Assign input processing and lightweight computations to CPUs, while offloading compute-heavy layers of the model to GPUs.
  • Embedding-Dense DI: Group embedding tables into dedicated submodules that can be served on separate hosts (similar to traditional parameter servers). Dense layers, which are smaller but compute-intensive, can be grouped and executed together for improved efficiency.
  • Dense Model Parallelism: Split a single dense network into multiple sub-networks that can be executed in parallel, either on different CUDA streams within the same device or across multiple devices, enabling selective lowering and parallel execution.

AI Compiler and High-Performance Kernel Libraries

To achieve maximum performance, developers may be tempted to rewrite model definitions in C++/CUDA and run them directly. However, this approach does not scale well. Instead, AI compilers can automate this process, generating highly optimized artifacts. Options include:

These compilers generate new, compiled artifacts that are packaged alongside the serialized model. For production RecSys deployments, C++ runtimes are preferred for performance reasons. This precludes the use of Python-dependent JIT workflows like torch.compile; instead, Ahead-of-Time (AOT) Inductor is used to precompile models into static runtime artifacts deployable in C++.

AI compilers utilize high-performance kernel libraries to maximize computational efficiency on various hardware platforms, including:

Request Coalescing

To maximize efficiency, requests should be coalesced (batched) together. This requires understanding the semantics of each input, particularly which dimension represents the dynamic batch size, so that requests can be concatenated appropriately. The model’s forward method should be tagged with batch information to facilitate coalescing, and the runtime must support this feature.
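A minimal sketch of coalescing: concatenate the requests along the batch dimension, run one forward pass, and split the output back per request. The `batch_dim` argument stands in for the batch information the model author would tag on the forward method; it assumes the model's output uses the same batch dimension as its input.

```python
from typing import List, Sequence

import torch
from torch import nn


def coalesced_forward(model: nn.Module, requests: Sequence[torch.Tensor],
                      batch_dim: int = 0) -> List[torch.Tensor]:
    """Run several requests as one batch and split the result per request."""
    sizes = [r.shape[batch_dim] for r in requests]
    batch = torch.cat(list(requests), dim=batch_dim)
    with torch.inference_mode():
        out = model(batch)  # one set of kernel launches for all requests
    return list(torch.split(out, sizes, dim=batch_dim))


if __name__ == "__main__":
    model = nn.Linear(4, 2)
    requests = [torch.randn(n, 4) for n in (1, 3, 2)]
    for out in coalesced_forward(model, requests):
        print(out.shape)
```

A real server would additionally bound the wait time while collecting requests, so that batching does not violate the latency budget.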

Table Batched Embedding

Querying embedding tables in PyTorch can incur significant operator kernel launch overhead, especially when dealing with tens, hundreds, or even thousands of tables. Since embedding lookups are data-transfer-heavy (akin to hash map queries), batching embedding bags and querying all tables in a single call can greatly reduce overhead and improve data transfer efficiency.
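The idea can be illustrated with stock PyTorch by storing all tables in one `nn.EmbeddingBag`, shifting each table's row ids by that table's starting row, and issuing a single pooled lookup. This naive sketch is for intuition only; production systems use fused table-batched embedding kernels such as those in FBGEMM (exposed through TorchRec). The input layout (per-table indices plus per-sample bag lengths) is an assumption of the sketch.

```python
from typing import List, Sequence

import torch
from torch import nn


class NaiveTableBatchedEmbeddingBag(nn.Module):
    """All tables share one storage; one EmbeddingBag call serves every table."""

    def __init__(self, rows_per_table: Sequence[int], dim: int):
        super().__init__()
        self.num_tables = len(rows_per_table)
        starts = [0]
        for n in rows_per_table[:-1]:
            starts.append(starts[-1] + n)
        self.register_buffer("row_offsets", torch.tensor(starts, dtype=torch.long))
        self.bag = nn.EmbeddingBag(sum(rows_per_table), dim, mode="sum")

    def forward(self, indices: List[torch.Tensor], lengths: List[torch.Tensor]) -> torch.Tensor:
        # indices[t]: row ids for table t, concatenated over the batch.
        # lengths[t]: [B] number of ids in each sample's bag for table t.
        all_indices = torch.cat([idx + self.row_offsets[t] for t, idx in enumerate(indices)])
        all_lengths = torch.cat(lengths)
        bag_offsets = torch.cat([all_lengths.new_zeros(1), all_lengths.cumsum(0)[:-1]])
        pooled = self.bag(all_indices, bag_offsets)  # [T * B, dim], table-major
        batch = lengths[0].numel()
        return pooled.view(self.num_tables, batch, -1).transpose(0, 1)  # [B, T, dim]
```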

Quantization

Both embedding and dense layers of the model can benefit significantly from quantization:

  • Embeddings: Data types like bf16 and int8 are generally safe, and int4 is often acceptable. Different tables and rows may have varying numerical sensitivities. PyTorch supports per-table quantization, even for table-batched embeddings, allowing developers to customize quantization strategies. Some tables may even use int1 or int2 configurations.
  • Dense Layers: Dense layers are more sensitive to quantization. Typically, fp16 and bf16 are acceptable for entire dense submodules, but there are exceptions: fp16 may lack sufficient dynamic range, and bf16 may lack sufficient precision. For further efficiency, fp8 and fp4 can be applied at the layer level, though this often requires manual tuning.

All quantization strategies should be validated through accuracy evaluation. TorchAO provides quantization support for Linear and Conv layers and is a good place to start.
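For embeddings, a common 8-bit scheme stores each row as unsigned 8-bit codes plus a per-row scale and offset, similar in spirit to the row-wise formats used by FBGEMM. The sketch below is a plain-PyTorch illustration of that idea, not a specific library's format. With fp32 scale and offset, a 128-dim row takes 128 + 8 = 136 bytes instead of 512.

```python
import torch


def quantize_rowwise_uint8(weight: torch.Tensor):
    """Asymmetric per-row 8-bit quantization of a [num_rows, dim] table."""
    w_min = weight.min(dim=1, keepdim=True).values
    w_max = weight.max(dim=1, keepdim=True).values
    scale = ((w_max - w_min) / 255.0).clamp(min=1e-8)  # avoid /0 for constant rows
    q = torch.round((weight - w_min) / scale).clamp(0, 255).to(torch.uint8)
    return q, scale, w_min


def dequantize_rowwise_uint8(q: torch.Tensor, scale: torch.Tensor, w_min: torch.Tensor):
    return q.to(scale.dtype) * scale + w_min


if __name__ == "__main__":
    table = torch.randn(1000, 128)
    q, scale, w_min = quantize_rowwise_uint8(table)
    err = (dequantize_rowwise_uint8(q, scale, w_min) - table).abs().max()
    print(f"max abs error: {err.item():.4f}")
```

Because each row has its own scale, a few outlier rows do not degrade the precision of the rest of the table; per-table choices of bit width, as described above, sit on top of this.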

Delta Update

Model freshness is critical for serving recommendation models. As models grow larger, loading the entire model becomes increasingly expensive. A balanced approach is to apply partial weight updates (delta updates). While implementing a protocol for data transfer is straightforward, tuning the weight loading pace is crucial to avoid disrupting serving. Embedding tables are generally more tolerant of partial updates, while dense modules are more sensitive. For dense modules, we recommend using a buffer module to support full module swaps, rather than updating weights individually.
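Both recommendations can be sketched in a few lines: paced, in-place row updates for embedding tables, and a buffer module that holds a fully loaded dense replacement and swaps it in with a single reference assignment. `chunk_size`, `pause_s`, and `SwappableModule` are hypothetical names chosen for illustration, not an existing API.

```python
import threading
import time

import torch
from torch import nn


@torch.no_grad()
def apply_embedding_delta(weight: torch.Tensor, row_ids: torch.Tensor,
                          new_rows: torch.Tensor, chunk_size: int = 4096,
                          pause_s: float = 0.0) -> None:
    """Overwrite selected embedding rows in place, a chunk at a time.
    `chunk_size` and `pause_s` pace the update so serving is not starved."""
    for start in range(0, row_ids.numel(), chunk_size):
        end = start + chunk_size
        weight.index_copy_(0, row_ids[start:end], new_rows[start:end])
        if pause_s > 0:
            time.sleep(pause_s)


class SwappableModule(nn.Module):
    """Holds the active dense module; a fully loaded replacement is swapped
    in with one reference assignment instead of per-weight updates."""

    def __init__(self, module: nn.Module):
        super().__init__()
        self.active = module
        self._swap_lock = threading.Lock()

    def swap(self, new_module: nn.Module) -> nn.Module:
        with self._swap_lock:
            old = self.active
            self.active = new_module
        return old  # caller can free it once in-flight requests finish

    def forward(self, *args, **kwargs):
        module = self.active  # one reference, so a request sees one version
        return module(*args, **kwargs)
```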

Developer Experience

Python Runtime

To streamline the development and debugging of the inference flow, we recommend providing a lightweight Python runtime environment (versus using the C++ runtime). This approach allows developers to efficiently determine whether issues originate from the runtime or the model itself. Additionally, it simplifies the process of adding instrumentation for debugging purposes.

With the introduction of free-threaded Python, both runtime and communication overhead can be further minimized within the Python ecosystem. This advancement also makes deploying Python runtimes in production environments increasingly practical.

Module Swap-Based Transformations

Historically, graph-based transformations have been challenging for model authors to understand and debug, largely due to the complexity of graph manipulations and the loss of original stack trace information. To address these issues, we recommend shifting such optimizations earlier in the inference module authoring process. By adopting a holistic, native PyTorch module-based workflow, and leveraging eager mode transformations, we have found that the inference development experience is significantly improved.

Eval Flow

To ensure both model and runtime quality, we recommend implementing the following two evaluation flows:

  • Accuracy Verification: Compare the inference model’s quality against training evaluation results.
  • Performance Benchmarking: Replay production-like traffic to assess throughput and latency.
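For the performance-benchmarking flow, a minimal closed-loop replay harness might look like the following. `handler` and the recorded `requests` are hypothetical placeholders for the serving entry point and captured production-like traffic.

```python
import statistics
import time
from typing import Any, Callable, Dict, Sequence


def replay_benchmark(handler: Callable[[Any], Any], requests: Sequence[Any],
                     warmup: int = 10) -> Dict[str, float]:
    """Replay requests sequentially and report throughput and latency percentiles."""
    for req in requests[:warmup]:  # warm caches, allocators, lazy init
        handler(req)
    latencies = []
    start = time.perf_counter()
    for req in requests:
        t0 = time.perf_counter()
        handler(req)
        latencies.append(time.perf_counter() - t0)
    wall = time.perf_counter() - start
    cuts = statistics.quantiles(latencies, n=100)  # 99 percentile cut points
    return {
        "qps": len(requests) / wall,
        "p50_ms": cuts[49] * 1e3,
        "p99_ms": cuts[98] * 1e3,
    }
```

A sequential replay measures per-request latency well but understates queuing effects. Replaying at a fixed target QPS from many concurrent clients is closer to real traffic.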

Conclusion

At Meta, we developed a highly efficient recommendation inference system built on PyTorch that is critical for translating cutting-edge research into production-grade services. This blog detailed a robust workflow, starting from a trained model definition and its weights, progressing through essential inference transformation steps, including graph capture, model splitting, optimizations (fusion, quantization, compilation, etc.), and finally serialization. We then outlined the requirements for a high-performance inference server, emphasizing a lightweight executor, flexible tensor-based APIs, and a DAG-based model execution model. Finally, we explored advanced optimization techniques crucial for high-QPS, low-latency performance, such as leveraging GPU/Accelerator inference, adopting a C++ runtime, implementing Distributed Inference patterns, utilizing AI compilers, and applying sophisticated methods like request coalescing, Table Batched Embeddings, and quantization. By adhering to these principles and utilizing the featured open-source libraries, developers can build scalable, performant, and agile PyTorch-based systems capable of serving the world’s most demanding ML recommendation workloads.

Related Libraries

TorchRec: A PyTorch domain library that powers Meta’s production recommender systems by providing the sparsity and parallelism primitives necessary to train and deploy models with massive embedding tables sharded across multiple GPUs.

TorchAO: TorchAO is an easy-to-use quantization library for native PyTorch. TorchAO works out of the box with torch.compile() and FSDP2 across most Hugging Face PyTorch models.

AITemplate: An open-source Python framework that transforms deep neural networks into highly optimized C++ code for NVIDIA and AMD GPUs, delivering near-roofline inference performance through unified hardware support and comprehensive operator fusion.

TensorRT: NVIDIA TensorRT is a developer ecosystem comprising inference compilers, runtimes, and model optimizations designed to deliver high-performance, low-latency deep learning inference for production applications.

Generative Recommenders / HSTU: A library that reformulates classical recommendation systems as generative models and introduces algorithms like HSTU and M-FALCON to drastically accelerate training and inference while establishing scaling laws for billion-user-scale environments.

FBGEMM: Highly-optimized kernels used across deep learning applications, including recommendation systems.

Triton and Low-Level Extension (TLX): Triton is a Python-based language and compiler designed for writing highly efficient GPU kernels. TLX (Triton Low-Level Extensions) is an experimental add-on that provides fine-grained, hardware-specific control within Triton, enabling developers to further optimize performance on modern GPUs.

oneDNN: oneAPI Deep Neural Network Library is an open-source, cross-platform performance library of basic building blocks for deep learning applications, specifically optimized for Intel processors.

ZenDNN: ZenDNN (Zen Deep Neural Network) Library accelerates deep learning inference applications on AMD CPUs.

CUTLASS / CuTeDSL: CUTLASS is a collection of abstractions for implementing high-performance matrix-matrix multiplication (GEMM) and related computations at all levels and scales within CUDA. CuTeDSL is a Python-based embedded domain-specific language for CUTLASS.

AITER: AITER is AMD's centralized repository of high-performance AI operators for accelerating AI workloads. It serves as a unified place for customers' operator-level requests and can be matched to different customers' needs.

CK: The Composable Kernel (CK) library provides a programming model for writing performance-critical kernels for machine learning workloads across multiple architectures (GPUs, CPUs, etc.). The CK library uses general purpose kernel languages, such as HIP C++.
