
Add MXFP8 attention #2719

Draft

cyanguwa wants to merge 50 commits into NVIDIA:main from cyanguwa:add_mxfp8

Conversation

@cyanguwa (Collaborator) commented Mar 1, 2026

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
This reverts commit d9ff566.

@cyanguwa cyanguwa marked this pull request as ready for review March 1, 2026 21:43
@cyanguwa cyanguwa marked this pull request as draft March 1, 2026 21:43
@cyanguwa cyanguwa closed this Mar 1, 2026
@cyanguwa cyanguwa deleted the add_mxfp8 branch March 1, 2026 21:44
@cyanguwa cyanguwa restored the add_mxfp8 branch March 1, 2026 21:50
greptile-apps bot (Contributor) commented Mar 1, 2026

Greptile Summary

This PR adds MXFP8 (Microscaling FP8) attention support to TransformerEngine, introducing a new quantization scheme with block-level scaling for attention operations.

Key Changes

  • Added MXFP8 scaling mode with FP8_E8M0 tensor reordering and block-level quantization
  • Introduced new BHSD_BHSD_BHSD layout for MXFP8 format with proper stride calculations
  • Extended fused attention APIs to support separate output formats (o_format, d_out_format, dqkv_layout)
  • Updated cuDNN frontend submodule for MXFP8 support (requires cuDNN >= 9.21.0, SM >= 10.0)
  • Added MXFP8BlockScaling recipe alongside existing DelayedScaling and CurrentScaling
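
To make the block-level scaling concrete, here is a minimal NumPy sketch of MXFP8-style quantization, assuming the OCP MX convention of one shared power-of-two (E8M0) scale per block of 32 values with FP8 E4M3 elements; `mxfp8_quantize_block` is an illustrative helper for this summary, not TransformerEngine's API.

```python
import numpy as np

def mxfp8_quantize_block(x, block=32, e4m3_max=448.0, emax=8):
    """Illustrative MXFP8-style quantization: each block of 32 values
    shares one power-of-two (E8M0) scale; a real kernel would then cast
    the scaled elements to FP8 E4M3 (here we only clip to its range)."""
    blocks = x.reshape(-1, block)
    amax = np.maximum(np.abs(blocks).max(axis=1, keepdims=True), 2.0**-126)
    exp = np.floor(np.log2(amax)) - emax   # E8M0 stores only this exponent
    scale = 2.0**exp
    q = np.clip(blocks / scale, -e4m3_max, e4m3_max)
    return q, scale

x = np.random.randn(2, 64).astype(np.float32)
q, scale = mxfp8_quantize_block(x)
recon = (q * scale).reshape(x.shape)  # dequantize: each block times its scale
```

The constants assume FP8 E4M3 (max normal 448, largest exponent 8); only the 8-bit exponent of each scale would be stored per block.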

Critical Issues Found

  • Multiple debug print statements left in production code across context_parallel.py, utils.py, grouped_tensor.py, and test files - will spam logs
  • Commented-out test assertions in test_attention.py (lines 2241-2251) - backward pass validation disabled, tests won't catch regressions
  • Commented-out environment variable NVTE_ALLOW_NONDETERMINISTIC_ALGO in tests - may cause test failures or unexpected behavior

Architecture

The implementation splits QK and V dimensions (d_qk vs d_v) to support Multi-Latent Attention (MLA), adds GroupedTensor storage for MXFP8 quantization with columnwise scaling, and extends backend selection logic to route MXFP8 workloads to appropriate kernels based on hardware capabilities.
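
As a rough illustration of that routing (a hypothetical helper; the real gating lives in TransformerEngine's backend-selection utilities), the version and architecture checks summarized above might look like:

```python
def select_attention_backend(recipe, cudnn_version, sm_arch):
    """Hypothetical sketch of backend routing: MXFP8 fused attention is
    gated on cuDNN >= 9.21.0 and SM >= 10.0 per the requirements above;
    other recipes fall through to the existing fused path."""
    if recipe == "mxfp8":
        if cudnn_version >= (9, 21, 0) and sm_arch >= (10, 0):
            return "fused_attn_mxfp8"
        return "unfused_fallback"  # unsupported cuDNN/GPU: fall back
    return "fused_attn"

backend = select_attention_backend("mxfp8", (9, 21, 0), (10, 0))
```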

Confidence Score: 2/5

  • Not safe to merge - contains debug code and disabled test assertions
  • Score reflects critical issues: production code has debug print statements that will spam logs, and backward pass test assertions are commented out meaning the tests won't catch regressions in MXFP8 backward pass correctness
  • Pay close attention to tests/pytorch/attention/test_attention.py (disabled assertions), transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py (3 print statements), and transformer_engine/pytorch/attention/dot_product_attention/utils.py (print + hard assertions)

Important Files Changed

  • tests/pytorch/attention/test_attention.py: test file with debug print statements and commented-out backward pass assertions for MXFP8 testing
  • transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py: major changes for MXFP8 support, with multiple debug print statements left in the code
  • transformer_engine/pytorch/attention/dot_product_attention/utils.py: MXFP8 quantizer setup and format conversion, with a debug print statement in combine_and_quantize
  • transformer_engine/common/fused_attn/fused_attn_fp8.cu: core CUDA implementation for MXFP8 attention with FP8_E8M0 tensor reordering
  • transformer_engine/common/fused_attn/utils.h: added BHSD_BHSD_BHSD layout support with proper stride calculations
  • transformer_engine/pytorch/tensor/storage/grouped_tensor.py: GroupedTensor storage with debug print statements during quantization
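
For the utils.h entry above, the "proper stride calculations" for a BHSD layout amount to row-major strides over [batch, heads, seq, dim]; a small sketch (illustrative helper, not the utils.h API):

```python
def bhsd_strides(b, h, s, d):
    """Row-major element strides for a [batch, heads, seq, dim] (BHSD)
    tensor; in the BHSD_BHSD_BHSD layout, Q, K, and V each use this same
    ordering, with their own seq and dim sizes (e.g. d_qk vs d_v)."""
    return (h * s * d, s * d, d, 1)

strides = bhsd_strides(2, 16, 128, 64)
```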

Last reviewed commit: d6ecadc

greptile-apps bot (Contributor) left a comment

33 files reviewed, 9 comments

Comment on lines 2241 to 2251

    print(f"fused_attn_bwd_fp8[{i}].max(): {fused_attn_bwd_fp8[i].max()}, fused_attn_bwd_f16[{i}].max(): {fused_attn_bwd_f16[i].max()}")
    print(f"fused_attn_bwd_fp8[{i}].min(): {fused_attn_bwd_fp8[i].min()}, fused_attn_bwd_f16[{i}].min(): {fused_attn_bwd_f16[i].min()}")
    # compare_and_assert(
    #     fused_attn_bwd_fp8[i],
    #     fused_attn_bwd_f16[i],
    #     f"fused_attn_bwd_fp8[{i}]",
    #     f"fused_attn_bwd_f16[{i}]",
    #     atol,
    #     rtol,
    #     rmse_tol,
    #     True,

Backward pass assertions commented out and replaced with debug prints - tests won't catch regressions.

Suggested change: drop the prints and restore the assertion:

    compare_and_assert(
        fused_attn_bwd_fp8[i],
        fused_attn_bwd_f16[i],
        f"fused_attn_bwd_fp8[{i}]",
        f"fused_attn_bwd_f16[{i}]",
        atol,
        rtol,
        rmse_tol,
        True,
    )

Context (context_parallel.py):

    q_fp8, k_fp8, v_fp8 = (None, None, None)
    # communicate for the 'a2a' part of 'a2a+p2p'
    if cp_size_a2a > 1:
        print(f">>>>>>======================>>>>>> {torch.cuda.current_device()}: fp8: {fp8}, is_input_fp8: {is_input_fp8}, fp8_recipe.mxfp8(): {fp8_recipe.mxfp8()}")

Debug print statement left in production code.

Suggested change: remove the print statement.

Context (context_parallel.py):

    # fp8 attention: q, k, v: torch.Tensor, dtype=torch.uint8
    out = None
    o_format = qkv_format
    for i in range(cp_size + 1):

Debug print statement left in production code.

Suggested change: remove the debug print, keeping only

    for i in range(cp_size + 1):

Context (context_parallel.py):

        softmax_lse_per_step[0],
        seq_dim,
    )
    print(f"====o/v===== {torch.cuda.current_device()}: i: {i}, {enable_mla}, out.shape: {out.shape} {out_per_step[0].shape} {v_shape} {o_shape}")

Debug print statement left in production code.

Suggested change: remove the print statement.

Context (utils.py):

    dO_quantizer.set_usage(rowwise=True, columnwise=False)
    dO_quantizer.internal = True

    dP_quantizer = quantizers["scaling_bwd"][META_DP]

Debug print statement in combine_and_quantize will spam logs in production.

Suggested change: remove the debug print, keeping only

    dP_quantizer = quantizers["scaling_bwd"][META_DP]


Context (test file):

    os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
    # os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"

NVTE_ALLOW_NONDETERMINISTIC_ALGO commented out - could affect test behavior or cause failures.

Suggested change: restore the environment variable:

    os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"

Comment on lines 2126 to +2129

    dP_quantizer.interal = True
    dP_quantizer.set_usage(rowwise=True, columnwise=False)

    dQKV_quantizer = quantizers["scaling_bwd"][META_DQKV]

Hard assertions on sequence/dimension alignment may fail in valid edge cases - consider providing clearer error messages.

greptile-apps bot (Contributor) commented Mar 1, 2026

Additional Comments (2)

transformer_engine/pytorch/tensor/storage/grouped_tensor.py
Debug print statement left in production code


tests/pytorch/attention/run_attention_with_cp.py
Debug print statement left in test script

@cyanguwa cyanguwa reopened this Mar 1, 2026
pre-commit-ci bot and others added 6 commits March 1, 2026 22:36
