Feature/unswizzle #2732

Open

int-smart wants to merge 14 commits into NVIDIA:main from int-smart:feature/unswizzle
Conversation

int-smart commented Mar 4, 2026

Description

This PR adds unswizzle support for scaling factors and extends the swizzle module so scaling tensors can be converted from GEMM-swizzled layout back to compact layout, including multi-tensor paths. It also adds round-trip and standalone tests to validate unswizzle correctness.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Added unswizzle APIs and implementation in transformer_engine/common/swizzle/swizzle.cu and declarations in transformer_engine/common/include/transformer_engine/swizzle.h
  • Added multi-tensor unswizzle support with swizzle-like validation assumptions (homogeneous scaling mode/layout, swizzled input and compact output expectations)
  • Refactored multi-tensor unswizzle launch/kernels to mirror swizzle structure (split row-wise and column-wise kernels) for easier readability
  • Added/extended tests in tests/cpp/operator/test_swizzle.cu, including standalone unswizzle and swizzle→unswizzle round-trip coverage

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

int-smart and others added 6 commits March 3, 2026 20:40
- Introduced `nvte_unswizzle_scaling_factors` to convert swizzled scaling factors back to row-major format.
- Implemented `regs_unshuffle_with_bit_shifts` and `regs_unshuffle` for unshuffling operations in CUDA kernels.
- Added `unswizzle_row_scaling_kernel_impl` and `unswizzle_col_scaling_kernel_impl` for handling unswizzling in row and column scaling respectively.

These changes enhance the functionality of the swizzle module, enabling better handling of scaling factors in tensor operations.

Signed-off-by: Abhishek <abhi.dtu11@gmail.com>
These enhancements test the changes introduced for unswizzling

Signed-off-by: Abhishek <abhi.dtu11@gmail.com>
- Introduced `compute_ref_unswizzle` to handle the conversion of swizzled scaling factors back to their original format.
- Added `performTestUnswizzle1D` to validate the unswizzling process with various scaling modes.
- Created `UnswizzleTestSuite` for comprehensive testing of unswizzling operations.

Signed-off-by: Abhishek <abhi.dtu11@gmail.com>
- Moved the definition of `swizzle_row_scaling_kernel` to a new location for better organization.
- Ensured the kernel implementation is now properly defined and accessible for scaling operations in the swizzle module.

Signed-off-by: Abhishek <abhi.dtu11@gmail.com>
- Introduced `multi_tensor_unswizzle_scaling_factors` to convert swizzled scaling factors back to their original row-major format.
- Implemented CUDA kernels for unswizzling in both row and column scaling, enhancing the swizzle module's functionality.
- Updated the launch function to handle multiple tensor unswizzling operations efficiently.

These changes improve the handling of scaling factors in tensor operations, ensuring better performance and organization within the swizzle module.

Signed-off-by: Abhishek <abhi.dtu11@gmail.com>
greptile-apps bot commented Mar 4, 2026

Greptile Summary

This PR adds nvte_unswizzle_scaling_factors and nvte_multi_tensor_unswizzle_scaling_factors — the inverse of the existing swizzle APIs — allowing GEMM-swizzled scaling factors to be converted back to compact layout. It mirrors the swizzle structure with separate row-wise and column-wise GPU kernels and a multi-tensor batching path.

Key open issues (carried over from review threads, not yet addressed):

  • Dual-scale tensors rejected in both unswizzle and swizzle (lines 1168, 1373): NVTE_CHECK(!has_rowwise_scale_inv || !has_columnwise_scale_inv, ...) prevents callers with both row- and column-wise scales from using either function. The separate kernels already handle both cases independently; the check appears to be an arbitrary limitation.
  • Uninitialized variables in test skip messages (test lines 158–160, 323–325): When exactly one of rowwise/columnwise is true, the other SF_MODE_* variable is never written before being passed to std::to_string() — undefined behaviour.
  • Rowwise output size check is trivially true (swizzle.cu line 1291): m and k are derived from output->scale_inv.shape, so m * k == output->scale_inv.numel() is always satisfied, providing no validation. Intended to guard against mismatched input/output buffer sizes, this should compare against the swizzled input's element count instead.
  • original_m_list/original_k_list populated but unused in multi_tensor_unswizzle_scaling_factors: the unswizzle kernels do not read those fields, creating a misleading parallel with the swizzle kernels (which need them for padding masks).
  • Round-trip test only covers aligned matrix dimensions: performTestSwizzleUnswizzleRoundtrip uses the existing num_tiles vector (all exact multiples of 128); non-aligned M/K round-trip correctness remains untested.
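The trivially-true check in the third point can be seen with a few lines of shape arithmetic (a hedged sketch; `numel`, `trivially_true_check`, and `meaningful_check` are illustrative helpers, not TE code):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Element count of a shape vector (illustrative helper).
static size_t numel(const std::vector<size_t>& shape) {
  size_t n = 1;
  for (size_t d : shape) n *= d;
  return n;
}

// When m and k are read from the output's own shape, comparing m * k
// against that same shape's element count is a tautology.
static bool trivially_true_check(const std::vector<size_t>& out_shape) {
  const size_t m = out_shape[0], k = out_shape[1];
  return m * k == numel(out_shape);  // always true for 2D shapes
}

// A meaningful guard would compare against the *input* buffer instead.
static bool meaningful_check(const std::vector<size_t>& out_shape,
                             const std::vector<size_t>& in_shape) {
  return numel(out_shape) == numel(in_shape);
}
```

Under this sketch, a mismatched input buffer passes the first check but fails the second, which is the behaviour the review asks for.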

Confidence Score: 2/5

  • Not safe to merge — several correctness issues remain open, including a trivially-true validation check that masks mismatched buffer sizes and undefined behaviour in test skip messages.
  • The unswizzle kernel logic itself appears structurally sound and mirrors the swizzle implementation correctly. However, the output-size validation check (line 1291) is trivially always-true, so mismatched input/output buffer sizes would silently go undetected. UB from uninitialized variables in test skip messages (lines 158–160, 323–325) could cause non-deterministic test behaviour on some compilers. Several issues flagged in review threads — dual-scale rejection, unused original_m_list/original_k_list, and round-trip tests limited to aligned dimensions — remain unaddressed.
  • transformer_engine/common/swizzle/swizzle.cu (validation logic at lines 1291 and 1327), tests/cpp/operator/test_swizzle.cu (uninitialized variables in skip messages at lines 158–160 and 323–325)

Important Files Changed

Filename Overview
transformer_engine/common/swizzle/swizzle.cu Adds unswizzle kernel implementations (row and column single- and multi-tensor paths). Several open correctness issues remain: the single-tensor unswizzle output-size check (line 1291) is trivially true since m and k are derived from output->scale_inv.shape, providing no actual validation; both single- and multi-tensor unswizzle still reject dual-scale tensors (MXFP8 with both row- and column-wise scales) while swizzle_scaling_factors also has the same restriction (lines 599-600), but round-trip callers expect symmetry. The unswizzle_col_scaling_kernel_impl SLM load reads k_tiles_in_tb tiles contiguously per M-tile, which is valid because the swizzle kernel writes them contiguously.
tests/cpp/operator/test_swizzle.cu Adds unswizzle standalone tests (with padded shapes) and a swizzle→unswizzle round-trip suite. Uninitialized-variable UB in skip messages (lines 158-160 and 323-325) remains unfixed: when exactly one of rowwise/columnwise is true, the other SF_MODE variable is uninitialized yet referenced in the GTEST_SKIP string. The round-trip suite re-uses the existing num_tiles vector (all exact multiples of 128), so non-aligned M/K round-trip behaviour is not exercised.
transformer_engine/common/include/transformer_engine/swizzle.h Adds nvte_unswizzle_scaling_factors and nvte_multi_tensor_unswizzle_scaling_factors declarations with clear documentation. No issues found.

Sequence Diagram

sequenceDiagram
    participant Caller
    participant nvte_unswizzle_scaling_factors
    participant unswizzle_scaling_factors
    participant unswizzle_scaling_kernel
    participant nvte_multi_tensor_unswizzle
    participant multi_tensor_unswizzle_scaling_factors
    participant launch_multi_tensor_unswizzle

    Caller->>nvte_unswizzle_scaling_factors: input (swizzled), output (compact)
    nvte_unswizzle_scaling_factors->>unswizzle_scaling_factors: validate scaling mode, flags

    note over unswizzle_scaling_factors: Rejects dual-scale tensors (both row+col)<br/>(same restriction as swizzle path)

    alt rowwise_unswizzle
        unswizzle_scaling_factors->>unswizzle_scaling_kernel: unswizzle_row_scaling_kernel_impl<br/>swizzled input → compact row-major
    end
    alt columnwise_unswizzle
        unswizzle_scaling_factors->>unswizzle_scaling_kernel: unswizzle_col_scaling_kernel_impl<br/>swizzled input → compact K-major
    end

    Caller->>nvte_multi_tensor_unswizzle: inputs[] (swizzled), outputs[] (compact)
    nvte_multi_tensor_unswizzle->>multi_tensor_unswizzle_scaling_factors: validate per-tensor
    multi_tensor_unswizzle_scaling_factors->>launch_multi_tensor_unswizzle: batch rowwise path
    launch_multi_tensor_unswizzle->>multi_tensor_unswizzle_row_scaling_kernel: <<<blocks>>> kernel args struct
    multi_tensor_unswizzle_scaling_factors->>launch_multi_tensor_unswizzle: batch columnwise path
    launch_multi_tensor_unswizzle->>multi_tensor_unswizzle_col_scaling_kernel: <<<blocks>>> kernel args struct

Last reviewed commit: bc1fb51

@vthumbe1503 vthumbe1503 added the community-contribution PRs from external contributor outside the core maintainers, representing community-driven work. label Mar 4, 2026
Signed-off-by: Abhishek <abhi.dtu11@gmail.com>
@int-smart int-smart force-pushed the feature/unswizzle branch from 85ea04b to 17dbb33 Compare March 5, 2026 02:13
int-smart and others added 2 commits March 4, 2026 18:49
ptrendx (Member) commented Mar 11, 2026

@int-smart Please address the comments from Greptile and ideally also add the test case with the input not already padded to 128,128.

int-smart (Author)

@ptrendx Will look into these

int-smart (Author)

@ptrendx From what I am understanding, then, padding has no relevance to the unswizzle kernel. Since the padding is already done during the swizzling operation, I can just mirror it back to compact layout with the zero pads placed correctly in the compact layout, and that should do. Is that assumption correct? Initially I was thinking of removing the padding from the scale_inv itself, since this would be used for checkpointing.

int-smart and others added 2 commits March 12, 2026 19:53
- Updated unswizzling kernel implementations to remove original_M and original_K parameters, simplifying the function signatures.
- Enhanced test suite to utilize new unswizzling data shapes, ensuring comprehensive coverage of aligned and padded cases.

These changes improve the clarity and efficiency of the unswizzling process in the swizzle module.
Signed-off-by: Abhishek <abhi.dtu11@gmail.com>
Comment on lines +1163 to +1166
const bool has_rowwise_scale_inv = input->scale_inv.has_data();
const bool has_columnwise_scale_inv = input->columnwise_scale_inv.has_data();
NVTE_CHECK(!has_rowwise_scale_inv || !has_columnwise_scale_inv,
"Input tensor has both row-wise and column-wise scaling factors");
Contributor:

Asymmetric handling of dual-scale tensors breaks round-trip correctness

unswizzle_scaling_factors explicitly rejects tensors that have both rowwise and columnwise scaling factors (line 1165–1166), but the counterpart swizzle_scaling_factors happily processes both scale types in a single call (it runs both the rowwise and columnwise swizzle paths sequentially).

This means calling the public round-trip pair —

nvte_swizzle_scaling_factors(input, swizzled, stream);   // succeeds: handles both scales
nvte_unswizzle_scaling_factors(swizzled, output, stream); // FAILS: "Input tensor has both..."

— will raise a runtime error for any MXFP8 tensor that carries both rowwise and columnwise scale factors (a common configuration in dual-path training).

The same asymmetry is present in the multi-tensor variant (multi_tensor_unswizzle_scaling_factors, line 1391–1392).

The fix is either:

  1. Support both scale types in the unswizzle path (mirror swizzle_scaling_factors), or
  2. Document the restriction in the header API comment so callers know to split the tensor or call two separate unswizzle invocations.

As-is, a user who relies on swizzle→unswizzle being a perfect inverse pair for the general case will encounter an API contract violation at runtime.

Member:

Right, this limitation seems arbitrary - you have all of the pieces to handle both row- and columnwise unswizzle, so why not call those 2 kernels if you need to do both?

Contributor:

You're absolutely right. The implementation already has both the row-wise and column-wise unswizzle kernels, and they execute independently in the if (rowwise_unswizzle) and if (columnwise_unswizzle) blocks. The check is arbitrary.

Replace line 1165-1166:

Suggested change
  const bool has_rowwise_scale_inv = input->scale_inv.has_data();
  const bool has_columnwise_scale_inv = input->columnwise_scale_inv.has_data();
- NVTE_CHECK(!has_rowwise_scale_inv || !has_columnwise_scale_inv,
-            "Input tensor has both row-wise and column-wise scaling factors");
+ NVTE_CHECK(has_rowwise_scale_inv || has_columnwise_scale_inv,
+            "Input tensor has neither row-wise nor column-wise scaling factors");

This mirrors the logic in swizzle_scaling_factors (line 472) and allows the two kernel paths to run sequentially when both scale types are present, making the round-trip work correctly for dual-scale tensors.

Contributor:

You're right - there's no technical reason for the limitation. The code already has separate rowwise and columnwise unswizzle kernels, so it should just invoke both when a tensor has both scale types, exactly like swizzle_scaling_factors does.

The fix is straightforward:

Suggested change
- const bool has_rowwise_scale_inv = input->scale_inv.has_data();
- const bool has_columnwise_scale_inv = input->columnwise_scale_inv.has_data();
- NVTE_CHECK(!has_rowwise_scale_inv || !has_columnwise_scale_inv,
-            "Input tensor has both row-wise and column-wise scaling factors");
+ const bool has_rowwise_scale_inv = input->scale_inv.has_data();
+ const bool has_columnwise_scale_inv = input->columnwise_scale_inv.has_data();
+ NVTE_CHECK(has_rowwise_scale_inv || has_columnwise_scale_inv,
+            "Input tensor has no scaling factors to unswizzle");

Then update the conditional logic below to run:

  • Rowwise unswizzle kernel if has_rowwise_scale_inv
  • Columnwise unswizzle kernel if has_columnwise_scale_inv
  • Both kernels if both are present

This makes the unswizzle path symmetric with swizzle and fixes the round-trip breakage.

Author:

@ptrendx I basically borrowed this behavior from the swizzle scaling factors: line ~599
NVTE_CHECK(!has_rowwise_scale_inv || !has_columnwise_scale_inv,
"Input tensor has both row-wise and column-wise scaling factors");
if (!has_rowwise_scale_inv && !has_columnwise_scale_inv) {
return;
}
What was the purpose of this check there? If I know more about it, I can then try to mend this check in the unswizzle path as well.

Contributor:

Looking at the swizzle code, the check at line 599 appears to be overly conservative - the implementation is already designed to handle both scale types.

After the check, the code explicitly sets up two independent execution paths (lines 659-677):

bool rowwise_swizzle{false}, columnwise_swizzle{false};
switch (scaling_mode) {
  case NVTE_MXFP8_1D_SCALING: {
    rowwise_swizzle = has_rowwise_scale_inv;
    columnwise_swizzle = has_columnwise_scale_inv;
    break;
  }
  // ...
}

// Perform row-wise swizzle
if (rowwise_swizzle) { /* ... launch rowwise kernel ... */ }

// Perform column-wise swizzle  
if (columnwise_swizzle) { /* ... launch columnwise kernel ... */ }

The two kernels execute independently and write to separate output buffers (output->scale_inv.dptr vs output->columnwise_scale_inv.dptr), so there's no technical conflict.

The check was likely added as a defensive measure or reflects an earlier design constraint that was later relaxed. Since the implementation already supports sequential execution of both kernel types, you should remove the check in both swizzle and unswizzle to enable proper round-trip support for dual-scale tensors.

ptrendx (Member) commented Mar 16, 2026

@int-smart I'm not sure I follow, I think that what you are saying is probably correct, but let me try to clarify just in case:

  • the scaling factors, irrespective of the compact or gemm-ready layout, are zero-padded to the multiple of [128,4] (or the transpose in case of compact and columnwise).
  • So for the unswizzle, you should just use the same size of the output unswizzled tensor as the original swizzled one. You don't even need to zero it before unswizzling, since the swizzled tensor already has 0s in the right places so unswizzling it will put 0s in the pad positions.
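The padding rule described above reduces to simple shape arithmetic. A minimal sketch, assuming one scale factor per 32-element block along K and zero-padding to a [128, 4] tile grid (`round_up` and `scale_shape` are illustrative helpers, not TE API):

```cpp
#include <cassert>
#include <utility>

// Round n up to the next multiple of m (illustrative helper).
static int round_up(int n, int m) { return ((n + m - 1) / m) * m; }

// Padded row-wise scale-factor shape for an M x K matrix, assuming one
// scale per 32-element block along K, zero-padded to a [128, 4] tile grid.
static std::pair<int, int> scale_shape(int M, int K) {
  const int scale_m = round_up(M, 128);
  const int scale_k = round_up(K / 32, 4);
  return {scale_m, scale_k};
}
```

Under these assumptions a 160x96 matrix yields a 256x4 scale tensor, matching the padded cases in unswizzle_data_shapes, and the unswizzle output simply reuses the padded shape of the swizzled input.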

int-smart (Author)

@ptrendx Makes sense. I added that in the last commit.

Comment on lines +1150 to +1240
switch (scaling_mode) {
case NVTE_MXFP8_1D_SCALING:
NVTE_CHECK(is_fp8_dtype(input->dtype()), "Input tensor has invalid dtype (expected FP8, got ",
to_string(input->dtype()), ").");
break;
case NVTE_NVFP4_1D_SCALING:
NVTE_CHECK(is_fp4_dtype(input->dtype()), "Input tensor has invalid dtype (expected FP4, got ",
to_string(input->dtype()), ").");
break;
default:
NVTE_ERROR("Invalid scaling mode");
}

const bool has_rowwise_scale_inv = input->scale_inv.has_data();
const bool has_columnwise_scale_inv = input->columnwise_scale_inv.has_data();
NVTE_CHECK(!has_rowwise_scale_inv || !has_columnwise_scale_inv,
"Input tensor has both row-wise and column-wise scaling factors");
if (!has_rowwise_scale_inv && !has_columnwise_scale_inv) {
return;
}

int m{0}, k{0};
switch (scaling_mode) {
case NVTE_MXFP8_1D_SCALING: {
if (has_rowwise_scale_inv) {
NVTE_CHECK(input->scale_inv.shape.size() == 2,
"Expected 2D scaling factors, got shape=", input->scale_inv.shape, ".");
m = input->scale_inv.shape[0];
k = input->scale_inv.shape[1];
} else if (has_columnwise_scale_inv) {
NVTE_CHECK(input->columnwise_scale_inv.shape.size() == 2,
"Expected 2D scaling factors, got shape=", input->columnwise_scale_inv.shape,
".");
m = input->columnwise_scale_inv.shape[1];
k = input->columnwise_scale_inv.shape[0];
}
break;
}
case NVTE_NVFP4_1D_SCALING: {
if (has_rowwise_scale_inv) {
NVTE_CHECK(input->scale_inv.shape.size() == 2,
"Expected 2D scaling factors, got shape=", input->scale_inv.shape, ".");
m = input->scale_inv.shape[0];
k = input->scale_inv.shape[1];
} else if (has_columnwise_scale_inv) {
NVTE_CHECK(input->columnwise_scale_inv.shape.size() == 2,
"Expected 2D scaling factors, got shape=", input->columnwise_scale_inv.shape,
".");
m = input->columnwise_scale_inv.shape[0];
k = input->columnwise_scale_inv.shape[1];
}
break;
}
default:
NVTE_ERROR("Invalid scaling mode");
}

constexpr int SF_TILE_DIM_M = 128;
constexpr int SF_TILE_DIM_K = 4;
NVTE_CHECK(m % SF_TILE_DIM_M == 0, "Input should be padded in M/N dimension!");
NVTE_CHECK(k % SF_TILE_DIM_K == 0, "Input should be padded in K dimension!");

if (has_rowwise_scale_inv) {
NVTE_CHECK(output->scale_inv.has_data(),
"Output tensor does not have row-wise scaling factors.");
}
if (has_columnwise_scale_inv) {
NVTE_CHECK(output->columnwise_scale_inv.has_data(),
"Output tensor does not have column-wise scaling factors.");
}

bool rowwise_unswizzle{false}, columnwise_unswizzle{false};
switch (scaling_mode) {
case NVTE_MXFP8_1D_SCALING: {
rowwise_unswizzle = has_rowwise_scale_inv;
columnwise_unswizzle = has_columnwise_scale_inv;
break;
}
case NVTE_NVFP4_1D_SCALING: {
rowwise_unswizzle = true;
columnwise_unswizzle = false;
break;
}
default:
NVTE_ERROR("Invalid scaling mode");
}

const dim3 block_size(TB_DIM, TB_DIM);
const int num_tiles_m = m / SF_TILE_DIM_M;
const int num_tiles_k = k / SF_TILE_DIM_K;

Member:

The code is pretty convoluted here and it doesn't have to be. There are some pieces there that you could do at the beginning without looking at the scaling factor (like checking whether the input has scale_inv/columnwise_scale_inv and checking if the output has them too). For the rest I would say that avoiding code duplication here is not worth breaking of the flow of NVFP4/MXFP8 specific logic, so I would probably just have a larger switch with 2 completely separate code paths rather than multiple switch statements.

Comment on lines +1212 to +1219
if (has_rowwise_scale_inv) {
NVTE_CHECK(output->scale_inv.has_data(),
"Output tensor does not have row-wise scaling factors.");
}
if (has_columnwise_scale_inv) {
NVTE_CHECK(output->columnwise_scale_inv.has_data(),
"Output tensor does not have column-wise scaling factors.");
}
Member:

I would say that the logic here is a little backwards, even though I understand how here it is not obvious. Ultimately it is the output that tells you what to do in the function - think about the quantize function where the input does not know anything about the format to which it is quantized and it is the output that controls scaling mode and whether we need rowwise or columnwise quantization. Therefore here I would also treat the output as a "source of truth" on what we need to do and then check that the input tensor provides the right data (as opposed to this code which looks to input to know what to do and then checks the output).
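A minimal sketch of that inversion, using simplified stand-in types (`Buf`, `Tensor`, and `validate_unswizzle` here are hypothetical mocks for illustration, not the real TE tensor classes):

```cpp
#include <cassert>
#include <stdexcept>

// Simplified stand-ins for illustration only.
struct Buf {
  bool present = false;
  bool has_data() const { return present; }
};
struct Tensor {
  Buf scale_inv, columnwise_scale_inv;
};

// Output-driven validation: the output declares what work is requested,
// then the input is checked against that request (mirroring how quantize
// treats its output as the source of truth).
void validate_unswizzle(const Tensor& in, const Tensor& out) {
  const bool want_rowwise = out.scale_inv.has_data();
  const bool want_columnwise = out.columnwise_scale_inv.has_data();
  if (!want_rowwise && !want_columnwise) return;  // nothing requested
  if (want_rowwise && !in.scale_inv.has_data())
    throw std::runtime_error("input lacks row-wise scaling factors");
  if (want_columnwise && !in.columnwise_scale_inv.has_data())
    throw std::runtime_error("input lacks column-wise scaling factors");
}
```

The point is only the direction of the checks: the output selects the rowwise/columnwise paths, and the input must satisfy that selection rather than the other way around.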

Author:

Changed this for the single-tensor case. Let me know if that makes sense. Can you tell me how this would be called, so that I can check how the input and output are allocated? From your comment above, I am currently assuming the output has all the necessary information to decide between rowwise, columnwise, the scaling_mode, and the data pointers, along with dimensions such as m and k. If this is fine, I can make these changes to the multi-tensor version as well.

Comment on lines +1279 to +1307
switch (vec_load_size) {
case 4:
NVTE_CHECK_CUDA(
cudaFuncSetAttribute(unswizzle_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
unswizzle_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(input_scale_inv_ptr,
output_scale_inv_ptr, m, k, true);
break;
case 2:
NVTE_CHECK_CUDA(
cudaFuncSetAttribute(unswizzle_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
unswizzle_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(input_scale_inv_ptr,
output_scale_inv_ptr, m, k, true);
break;
case 1:
NVTE_CHECK_CUDA(
cudaFuncSetAttribute(unswizzle_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
unswizzle_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(input_scale_inv_ptr,
output_scale_inv_ptr, m, k, true);
break;
default:
NVTE_ERROR("Not valid vec_load_size.");
break;
}
Member:

This code is repeated multiple times and the rowwise vs columnwise differs only by the arguments to the kernel, not the template arguments. I think it would be better to have something like

auto kernel = [vec_load_size]() {
  switch (vec_load_size) {
    case 4:
      return unswizzle_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>;
    case 2:
      return unswizzle_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>;
    case 1:
      return unswizzle_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>;
    default:
      NVTE_ERROR("Not valid vec_load_size.");
  }
}();

before the if (rowwise_unswizzle) and then using that inside the if statements.

Author:

Fixed this. I am trying to stay as close as possible to the swizzle implementation, since that is what sets the coding style for me. Since a similar implementation also exists in launch_multi_tensor_unswizzle_scaling_factors, let me know if that needs to be abstracted out as well. It won't be as clean as this one, since there are two separate kernels in that case: I could have one lambda decide the kernel name (multi_tensor_unswizzle_row_scaling_kernel or multi_tensor_unswizzle_col_scaling_kernel) and another decide the vector size or LType. I am not sure I am a fan of calling so many functions.

Signed-off-by: Abhishek <abhi.dtu11@gmail.com>
Comment on lines +1449 to +1461
kernel_args.block_range[0] = 0;
int vec_load_size = 4;
for (size_t i = 0; i < num_tensors; i++) {
if (kernel_args.num_tensors == kMaxTensorsPerKernel) {
if (vec_load_size == 3) vec_load_size = 1;
launch_multi_tensor_unswizzle_scaling_factors<SF_TILE_DIM_M, SF_TILE_DIM_K>(
kernel_args, vec_load_size, false, stream);
kernel_args.num_tensors = 0;
vec_load_size = 4;
}
const int m = input[i]->columnwise_scale_inv.shape[1];
const int k = input[i]->columnwise_scale_inv.shape[0];

Contributor:

P2 original_m_list/original_k_list set but unused by unswizzle kernels

Inside multi_tensor_unswizzle_scaling_factors, both the rowwise path (lines ~955–956 and ~960–961) and the columnwise path (lines ~1006–1007) populate kernel_args.original_m_list[pos] and kernel_args.original_k_list[pos]. However, neither multi_tensor_unswizzle_row_scaling_kernel nor multi_tensor_unswizzle_col_scaling_kernel reads these fields — they only consume m_list and k_list. The swizzle kernels need the original (unpadded) dimensions to zero-fill padding, but the unswizzle kernels always operate on already-padded swizzled input and produce padded compact output, so no masking is required.

Setting unused struct fields is harmless today but adds noise and could mislead a reader into thinking the unswizzle kernels honour padding boundaries the same way the swizzle kernels do. Consider either removing these assignments or adding a comment explaining why they are intentionally populated (e.g., "kept for future per-element padding masking").

Comment on lines +951 to +966
<<<num_blocks, block_size, slm_size, stream>>>(kernel_args);
break;
case 1:
NVTE_CHECK_CUDA(cudaFuncSetAttribute(
multi_tensor_unswizzle_col_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
multi_tensor_unswizzle_col_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(kernel_args);
break;
default:
NVTE_ERROR("Not valid vec_load_size.");
break;
}
}
NVTE_CHECK_CUDA(cudaGetLastError());
}
Contributor:

P2 Shared memory load may OOB-read across M-tile boundaries in unswizzle_col_scaling_kernel_impl

The SLM load loop treats each M-tile's K-tiles as a flat, contiguous array:

for (int j = linear_id; j < SF_TILE_SIZE_I32 * k_tiles_in_tb / 4; j += blockDim.x * blockDim.y) {
    slm_v4i[j] = input_v4i[j];
}

input_v4i is derived from input_i32[i], which points to the start of M-tile i within the current K-tile group. In the GEMM-swizzled layout the K-tiles for a single M-tile are not stored contiguously: they are separated by SF_TILE_DIM_M_I32 * K_i32 int32 values (i.e., one full "column" stride). Reading SF_TILE_SIZE_I32 * k_tiles_in_tb contiguous int32s from that pointer therefore walks across unrelated data in memory once past the first K-tile, producing incorrect SLM contents for all but the first K-tile.

Compare with swizzle_col_scaling_kernel_impl, which reads input in compact (M-major) format where K-tiles for a given M-tile are contiguous in memory — that is why the flat input_i32 pointer arithmetic works there.

Please verify that the swizzled input layout actually stores all K-tiles for a given M-tile contiguously (one tile block per thread-block), or restructure the load to stride by the correct per-K-tile offset.
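Plain index arithmetic illustrates the concern. This is a sketch under the layout assumptions stated in this comment (tile size and stride names mirror the kernel's, but the values are illustrative, not verified against the source):

```cpp
#include <cassert>

// One 128x4-byte scale tile read as int32s (assumed tile size).
constexpr int SF_TILE_SIZE_I32 = 128;

// Offset (in int32s) of K-tile j for a fixed M-tile, if K-tiles were
// stored back-to-back -- the assumption the flat SLM load makes.
static int contiguous_offset(int j) { return j * SF_TILE_SIZE_I32; }

// Offset if consecutive K-tiles of one M-tile are instead separated by a
// full column stride of sf_tile_dim_m_i32 * k_i32 int32s, as described above.
static int strided_offset(int j, int sf_tile_dim_m_i32, int k_i32) {
  return j * sf_tile_dim_m_i32 * k_i32;
}
```

For j > 0 the two offsets disagree whenever the column stride exceeds one tile, which is exactly the out-of-bounds concern raised above: a flat read lands on unrelated data after the first K-tile.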

Comment on lines +249 to 263
std::vector<std::pair<size_t, size_t>> unswizzle_data_shapes = {
// Aligned: scale dims are already multiples of 128 and 4
{128, 128},
{128, 16896}, // K = 132 * 128, large K
{16896, 128}, // M = 132 * 128, large M
// M-padding only: M not a multiple of 128 (scale-M needs padding to 256)
{160, 128},
// scale-K padding only: K/32 = 3, padded to 4
{128, 96},
// Both M and scale-K need padding
{160, 96},
};

std::vector<std::pair<bool, bool>> scaling_mode = {
{true, false},
Contributor:

P2 Roundtrip test only covers aligned matrix dimensions

performTestSwizzleUnswizzleRoundtrip is instantiated exclusively with the existing num_tiles vector, which always produces M = num_tiles_M * MAT_TILE_DIM_M — values that are exact multiples of 128 (the scale-M alignment). The standalone performTestUnswizzle1D intentionally adds padded shapes (e.g., M=160, K=96) via unswizzle_data_shapes, but no equivalent padded cases exist for the roundtrip.

If the output-size validation or padding-mask logic ever diverges between the swizzle and unswizzle paths for non-aligned M/K, the roundtrip test would pass while standalone tests fail (or vice-versa). Consider adding a few padded shapes (e.g., {4, 3} tile-count pairs or raw {160, 96} shapes) to num_tiles or creating a separate data-shape vector for the roundtrip suite.

Signed-off-by: Abhishek <abhi.dtu11@gmail.com>
Comment on lines +158 to +161
if (rowwise && columnwise) {
GTEST_SKIP() << "TEST SKIPPED, The scaling mode " + std::to_string(SF_MODE_X) + "x" +
std::to_string(SF_MODE_Y) + " is not implemented.";
}
Contributor:

P1 UB: SF_MODE_X/SF_MODE_Y uninitialized when only one flag is true

When rowwise && columnwise is true, both if (rowwise) and if (columnwise) branches above execute so the variables are defined — but when exactly one of the two flags is set (the common test-suite case), only one branch runs and the other variable is never written. Referencing both in std::to_string(SF_MODE_X) + "x" + std::to_string(SF_MODE_Y) is undefined behaviour.

The same issue exists in performTestSwizzleUnswizzleRoundtrip at lines 323–326.

A safe fix initialises both variables at declaration and skips early when both are false:

  int SF_MODE_X = 0, SF_MODE_Y = 0;
  if (!rowwise && !columnwise) {
    GTEST_SKIP() << "TEST SKIPPED, Either rowwise or columnwise scaling mode must be true.";
  }
  if (rowwise)  { SF_MODE_X = 1;  SF_MODE_Y = 32; }
  if (columnwise) { SF_MODE_X = 32; SF_MODE_Y = 1;  }
  if (rowwise && columnwise) {
    GTEST_SKIP() << "TEST SKIPPED, The scaling mode " + std::to_string(SF_MODE_X) + "x" +
      std::to_string(SF_MODE_Y) + " is not implemented.";
  }
