Fix #8462: embed patch sizes in einops pattern for einops >= 0.8 compatibility #8834
base: dev
```diff
@@ -29,6 +29,32 @@
 SUPPORTED_POS_EMBEDDING_TYPES = {"none", "learnable", "sincos", "fourier"}
 
 
+class _PatchRearrange(nn.Module):
+    """Fallback patch rearrangement using pure PyTorch, for einops compatibility."""
+
+    def __init__(self, spatial_dims: int, patch_size: tuple) -> None:
+        super().__init__()
+        self.spatial_dims = spatial_dims
+        self.patch_size = patch_size
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        batch, channels = x.shape[0], x.shape[1]
+        sp = x.shape[2:]
+        g = tuple(s // p for s, p in zip(sp, self.patch_size))
+        v: list[int] = [batch, channels]
+        for gi, pi in zip(g, self.patch_size):
+            v += [gi, pi]
+        x = x.view(*v)
+        n = self.spatial_dims
+        gdims = list(range(2, 2 + 2 * n, 2))
+        pdims = list(range(3, 3 + 2 * n, 2))
+        x = x.permute(0, *gdims, *pdims, 1).contiguous()
+        n_patches = 1
+        for gi in g:
+            n_patches *= gi
+        return x.reshape(batch, n_patches, -1)
+
+
 class PatchEmbeddingBlock(nn.Module):
     """
     A patch embedding block, based on: "Dosovitskiy et al.,
```
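The view/permute/reshape trick in `_PatchRearrange.forward` can be mirrored with the same index math in NumPy, which makes the transformation easy to inspect outside PyTorch (a sketch; the function name, shapes, and patch sizes here are illustrative):

```python
import numpy as np

def patch_rearrange(x: np.ndarray, patch_size: tuple) -> np.ndarray:
    """NumPy mirror of _PatchRearrange.forward: b c (h p1) (w p2) -> b (h w) (p1 p2 c)."""
    batch, channels = x.shape[0], x.shape[1]
    spatial = x.shape[2:]
    n = len(spatial)
    grid = tuple(s // p for s, p in zip(spatial, patch_size))
    shape = [batch, channels]
    for g, p in zip(grid, patch_size):
        shape += [g, p]
    x = x.reshape(shape)                             # split each spatial dim into (grid, patch)
    gdims = list(range(2, 2 + 2 * n, 2))             # indices of the grid axes
    pdims = list(range(3, 3 + 2 * n, 2))             # indices of the patch axes
    x = x.transpose(0, *gdims, *pdims, 1)            # (b, grid..., patch..., c)
    return x.reshape(batch, int(np.prod(grid)), -1)  # (b, n_patches, patch_dim)

x = np.arange(1 * 3 * 8 * 8).reshape(1, 3, 8, 8)
out = patch_rearrange(x, (4, 4))
print(out.shape)  # (1, 4, 48): a 2x2 grid of 4x4 patches, each flattened with channels
# Each row is one patch, flattened in (p1, p2, c) order:
assert (out[0, 0] == x[0, :, :4, :4].transpose(1, 2, 0).ravel()).all()
```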
|
```diff
@@ -97,14 +123,16 @@ def __init__(
                 in_channels=in_channels, out_channels=hidden_size, kernel_size=patch_size, stride=patch_size
             )
         elif self.proj_type == "perceptron":
-            # for 3d: "b c (h p1) (w p2) (d p3)-> b (h w d) (p1 p2 p3 c)"
+            # for 3d: "b c (h p1) (w p2) (d p3) -> b (h w d) (p1 p2 p3 c)"
             chars = (("h", "p1"), ("w", "p2"), ("d", "p3"))[:spatial_dims]
             from_chars = "b c " + " ".join(f"({k} {v})" for k, v in chars)
             to_chars = f"b ({' '.join([c[0] for c in chars])}) ({' '.join([c[1] for c in chars])} c)"
             axes_len = {f"p{i + 1}": p for i, p in enumerate(patch_size)}
-            self.patch_embeddings = nn.Sequential(
-                Rearrange(f"{from_chars} -> {to_chars}", **axes_len), nn.Linear(self.patch_dim, hidden_size)
-            )
+            try:
+                rearrange_layer: nn.Module = Rearrange(f"{from_chars} -> {to_chars}", **axes_len)
+            except TypeError:
+                rearrange_layer = _PatchRearrange(spatial_dims, tuple(int(p) for p in patch_size))
+            self.patch_embeddings = nn.Sequential(rearrange_layer, nn.Linear(self.patch_dim, hidden_size))
```
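The pattern-string construction in the hunk above is pure string building and can be sketched standalone, without einops installed (the helper name `build_patterns` is illustrative, not from the PR):

```python
def build_patterns(spatial_dims: int, patch_size: tuple):
    """Rebuild the einops pattern and axes_lengths dict the same way the PR does."""
    chars = (("h", "p1"), ("w", "p2"), ("d", "p3"))[:spatial_dims]
    from_chars = "b c " + " ".join(f"({k} {v})" for k, v in chars)
    to_chars = f"b ({' '.join(c[0] for c in chars)}) ({' '.join(c[1] for c in chars)} c)"
    axes_len = {f"p{i + 1}": p for i, p in enumerate(patch_size)}
    return f"{from_chars} -> {to_chars}", axes_len

pattern, axes_len = build_patterns(3, (16, 16, 16))
print(pattern)   # b c (h p1) (w p2) (d p3) -> b (h w d) (p1 p2 p3 c)
print(axes_len)  # {'p1': 16, 'p2': 16, 'p3': 16}
```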
Comment on lines +131 to +135 (Contributor):

Fallback path isn't deterministically covered by tests. As per coding guidelines: "Ensure new or modified definitions will be covered by existing or new unit tests."
```diff
 
         self.position_embeddings = nn.Parameter(torch.zeros(1, self.n_patches, hidden_size))
         self.dropout = nn.Dropout(dropout_rate)
```
🧩 Analysis chain
🌐 Web query: `einops Rearrange pattern integer literals support 0.8`

💡 Result: No, einops Rearrange patterns do not support standalone integer literals (e.g., `'b 2 h w'`) as axis specifications. Patterns use symbolic axis names (letters like `'b'`, `'h'`), anonymous axes (numbers in parentheses like `'(b1 b2)'`), ellipsis (`...`), or underscores (`_` for skipping). Integer values are provided via the `axes_lengths` keyword argument (e.g., `rearrange(x, 'b (h1 h) w c -> ...', h1=2)`) or parsed from tensor shapes via `parse_shape`. This syntax has been consistent across versions, including 0.8.0+ (released 2024-04-28), with no changes introducing integer literal support in patterns per release notes, changelogs, and docs. Examples in official documentation and tutorials confirm named axes only.
🏁 Scripts executed (Repository: Project-MONAI/MONAI; outputs collapsed):

```shell
# Check the test file for coverage of _PatchRearrange and the fallback path
cat ./tests/networks/blocks/test_patchembedding.py

# Check einops version constraints and any comments about the fallback logic
head -30 monai/networks/blocks/patchembedding.py
```
Retract the simplification suggestion; the current try/except approach is necessary.

einops `Rearrange` does not support integer literals in patterns; numeric axis values must be passed via `axes_lengths` (the current `axes_len` dict). The suggested approach of embedding integers as literals (e.g., `"b c (h 16) (w 16) (d 16)"`) is not feasible. The fallback with `_PatchRearrange` is the correct strategy for version compatibility.

However, address these remaining issues in `_PatchRearrange`:

- Add docstrings to the `__init__` and `forward` methods describing arguments, return values, and behavior per coding guidelines.
- Tighten the annotation from `patch_size: tuple` to `patch_size: tuple[int, ...]`.
- Use `reshape()` instead of `view()`: line 47 uses `x.view(*v)`, which fails on non-contiguous tensors; `reshape()` is safer.
- The existing tests only exercise the `Rearrange` path (since einops is installed), so the fallback is never deterministically validated. Add a test that directly instantiates and tests `_PatchRearrange` independently.