Conversation
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
This reverts commit d9ff566. Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Greptile Summary

This PR adds MXFP8 (Microscaling FP8) attention support to TransformerEngine, introducing a new quantization scheme with block-level scaling for attention operations.

Key Changes

Critical Issues Found

Architecture

The implementation splits QK and V dimensions (…)

Confidence Score: 2/5

Important Files Changed

Last reviewed commit: d6ecadc
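To make the "block-level scaling" in the summary concrete, here is a minimal NumPy sketch of the MXFP8 idea: each block of 32 consecutive values shares one power-of-two (E8M0-style) scale chosen so the scaled values fit in the e4m3 range. This is an illustration of the scheme only, not TransformerEngine's implementation; the rounding of scaled values to actual e4m3 codes is omitted.

```python
import numpy as np

BLOCK = 32            # MX block size: 32 consecutive values share one scale
FP8_E4M3_MAX = 448.0  # largest finite magnitude representable in e4m3

def mxfp8_quantize(x):
    """Block-scale a 1-D float32 array MXFP8-style (e4m3 rounding omitted)."""
    n = x.size
    pad = (-n) % BLOCK
    blocks = np.pad(x.astype(np.float32), (0, pad)).reshape(-1, BLOCK)
    amax = np.abs(blocks).max(axis=1, keepdims=True)
    # Power-of-two scale per block, chosen so every value fits in e4m3 range
    exp = np.ceil(np.log2(np.maximum(amax, 1e-30) / FP8_E4M3_MAX))
    scale = np.exp2(exp)
    return blocks / scale, scale, n

def mxfp8_dequantize(q, scale, n):
    # Power-of-two scaling is exact in binary floating point, so this
    # round-trips exactly (real MXFP8 loses precision in the e4m3 rounding).
    return (q * scale).reshape(-1)[:n]

rng = np.random.default_rng(0)
x = rng.standard_normal(100).astype(np.float32)
q, scale, n = mxfp8_quantize(x)
x_rec = mxfp8_dequantize(q, scale, n)
```

Because the scales are per-32-element block rather than per-tensor, a single outlier only degrades the precision of its own block, which is the main motivation for microscaling formats.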
```python
print(f"fused_attn_bwd_fp8[{i}].max(): {fused_attn_bwd_fp8[i].max()}, fused_attn_bwd_f16[{i}].max(): {fused_attn_bwd_f16[i].max()}")
print(f"fused_attn_bwd_fp8[{i}].min(): {fused_attn_bwd_fp8[i].min()}, fused_attn_bwd_f16[{i}].min(): {fused_attn_bwd_f16[i].min()}")
# compare_and_assert(
#     fused_attn_bwd_fp8[i],
#     fused_attn_bwd_f16[i],
#     f"fused_attn_bwd_fp8[{i}]",
#     f"fused_attn_bwd_f16[{i}]",
#     atol,
#     rtol,
#     rmse_tol,
#     True,
```
Backward pass assertions commented out and replaced with debug prints - tests won't catch regressions
Suggested change (restore the assertion in place of the debug prints):

```python
compare_and_assert(
    fused_attn_bwd_fp8[i],
    fused_attn_bwd_f16[i],
    f"fused_attn_bwd_fp8[{i}]",
    f"fused_attn_bwd_f16[{i}]",
    atol,
    rtol,
    rmse_tol,
    True,
)
```
```python
q_fp8, k_fp8, v_fp8 = (None, None, None)
# communicate for the 'a2a' part of 'a2a+p2p'
if cp_size_a2a > 1:
    print(f">>>>>>======================>>>>>> {torch.cuda.current_device()}: fp8: {fp8}, is_input_fp8: {is_input_fp8}, fp8_recipe.mxfp8(): {fp8_recipe.mxfp8()}")
```
Debug print statement left in production code
Suggested change (delete this line):

```python
print(f">>>>>>======================>>>>>> {torch.cuda.current_device()}: fp8: {fp8}, is_input_fp8: {is_input_fp8}, fp8_recipe.mxfp8(): {fp8_recipe.mxfp8()}")
```
```python
# fp8 attention: q, k, v: torch.Tensor, dtype=torch.uint8
out = None
o_format = qkv_format
for i in range(cp_size + 1):
```
Debug print statement left in production code
```python
    softmax_lse_per_step[0],
    seq_dim,
)
print(f"====o/v===== {torch.cuda.current_device()}: i: {i}, {enable_mla}, out.shape: {out.shape} {out_per_step[0].shape} {v_shape} {o_shape}")
```
Debug print statement left in production code
Suggested change (delete this line):

```python
print(f"====o/v===== {torch.cuda.current_device()}: i: {i}, {enable_mla}, out.shape: {out.shape} {out_per_step[0].shape} {v_shape} {o_shape}")
```
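Several comments in this review flag the same pattern: bare `print` calls left in library code. A common alternative, sketched here under the assumption that the module can use Python's standard `logging` (the logger name and the `NVTE_DEBUG` switch are illustrative, not the library's actual API), is to route diagnostics through a module logger so they are silent unless explicitly enabled:

```python
import logging
import os

# Hypothetical module logger; "NVTE_DEBUG" is an illustrative env switch.
logger = logging.getLogger("cp_attention")
if os.getenv("NVTE_DEBUG", "0") == "1":
    logging.basicConfig(level=logging.DEBUG)

def log_step(i, out_shape):
    # Lazy %-style args: the message is only formatted if DEBUG is enabled,
    # so this costs almost nothing in production.
    logger.debug("cp step %d: out.shape=%s", i, out_shape)

log_step(0, (2, 128, 16, 64))  # silent unless NVTE_DEBUG=1
```

This keeps the diagnostics available for debugging multi-rank runs without spamming stdout in production.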
```python
dO_quantizer.set_usage(rowwise=True, columnwise=False)
dO_quantizer.internal = True

dP_quantizer = quantizers["scaling_bwd"][META_DP]
```
Debug print statement in combine_and_quantize will spam logs in production
```python
os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"
# os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"
```
NVTE_ALLOW_NONDETERMINISTIC_ALGO commented out - could affect test behavior or cause failures
Suggested change:

```diff
-# os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"
+os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"
```
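Mutating `os.environ` directly in test setup, as the lines above do, can leak settings like `NVTE_ALLOW_NONDETERMINISTIC_ALGO` into later tests. One way to contain that (a generic sketch, not part of this PR or of TransformerEngine) is a context manager that restores the original values on exit:

```python
import os
from contextlib import contextmanager

@contextmanager
def env_override(**overrides):
    """Temporarily set environment variables; restore originals on exit."""
    saved = {k: os.environ.get(k) for k in overrides}
    os.environ.update(overrides)
    try:
        yield
    finally:
        for k, old in saved.items():
            if old is None:
                os.environ.pop(k, None)
            else:
                os.environ[k] = old

with env_override(NVTE_FP8_DPA_BWD="1", NVTE_ALLOW_NONDETERMINISTIC_ALGO="1"):
    assert os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] == "1"
```

With pytest, the built-in `monkeypatch.setenv` fixture achieves the same isolation without custom code.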
```python
dP_quantizer.internal = True
dP_quantizer.set_usage(rowwise=True, columnwise=False)

dQKV_quantizer = quantizers["scaling_bwd"][META_DQKV]
```
Hard assertions on sequence/dimension alignment may fail in valid edge cases - consider providing clearer error messages
for more information, see https://pre-commit.ci
Description
Please include a brief summary of the changes, relevant motivation and context.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: