From 78756f808b43af5601b06780d73d134623d2d3d3 Mon Sep 17 00:00:00 2001 From: UGBOMEH OGOCHUKWU WILLIAMS Date: Sat, 18 Apr 2026 14:33:58 +0100 Subject: [PATCH 1/3] Fix #8462: embed patch sizes in einops pattern for einops >= 0.8 compatibility einops 0.8.x removed support for arbitrary kwargs in Rearrange.__init__(). Replace axes_len dict and **axes_len kwarg with integer literals embedded directly in the pattern string. Semantically identical, compatible with all einops versions. Signed-off-by: UGBOMEH OGOCHUKWU WILLIAMS --- monai/networks/blocks/patchembedding.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/monai/networks/blocks/patchembedding.py b/monai/networks/blocks/patchembedding.py index a4caae68be..b7517363cd 100644 --- a/monai/networks/blocks/patchembedding.py +++ b/monai/networks/blocks/patchembedding.py @@ -97,13 +97,12 @@ def __init__( in_channels=in_channels, out_channels=hidden_size, kernel_size=patch_size, stride=patch_size ) elif self.proj_type == "perceptron": - # for 3d: "b c (h p1) (w p2) (d p3)-> b (h w d) (p1 p2 p3 c)" - chars = (("h", "p1"), ("w", "p2"), ("d", "p3"))[:spatial_dims] - from_chars = "b c " + " ".join(f"({k} {v})" for k, v in chars) - to_chars = f"b ({' '.join([c[0] for c in chars])}) ({' '.join([c[1] for c in chars])} c)" - axes_len = {f"p{i + 1}": p for i, p in enumerate(patch_size)} + # for 3d: "b c (h 16) (w 16) (d 16) -> b (h w d) (16 16 16 c)" + dim_names = ("h", "w", "d")[:spatial_dims] + from_chars = "b c " + " ".join(f"({name} {psize})" for name, psize in zip(dim_names, patch_size)) + to_chars = f"b ({' '.join(dim_names)}) ({' '.join(str(p) for p in patch_size)} c)" self.patch_embeddings = nn.Sequential( - Rearrange(f"{from_chars} -> {to_chars}", **axes_len), nn.Linear(self.patch_dim, hidden_size) + Rearrange(f"{from_chars} -> {to_chars}"), nn.Linear(self.patch_dim, hidden_size) ) self.position_embeddings = nn.Parameter(torch.zeros(1, self.n_patches, hidden_size)) self.dropout = nn.Dropout(dropout_rate) From 093f0b04bbd937b728c31405d6ee3ba2bbe3cf07 Mon Sep 17 00:00:00 2001 From: UGBOMEH OGOCHUKWU WILLIAMS Date: Sat, 18 Apr 2026 17:44:08 +0100 Subject: [PATCH 2/3] Fix #8462: handle einops >= 0.8 Rearrange kwargs API change In some einops 0.8.x builds, Rearrange.__init__() does not accept **kwargs for axis sizes, raising TypeError. Add _PatchRearrange as a pure-PyTorch fallback that produces identical output. The primary path still uses Rearrange with named axes (**axes_len); the fallback only activates if that call raises TypeError. Signed-off-by: UGBOMEH OGOCHUKWU WILLIAMS --- monai/networks/blocks/patchembedding.py | 43 +++++++++++++++++++++---- 1 file changed, 36 insertions(+), 7 deletions(-) diff --git a/monai/networks/blocks/patchembedding.py b/monai/networks/blocks/patchembedding.py index b7517363cd..e033c49ab2 100644 --- a/monai/networks/blocks/patchembedding.py +++ b/monai/networks/blocks/patchembedding.py @@ -29,6 +29,32 @@ SUPPORTED_POS_EMBEDDING_TYPES = {"none", "learnable", "sincos", "fourier"} +class _PatchRearrange(nn.Module): + """Fallback patch rearrangement using pure PyTorch, for einops compatibility.""" + + def __init__(self, spatial_dims: int, patch_size: tuple) -> None: + super().__init__() + self.spatial_dims = spatial_dims + self.patch_size = patch_size + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, C = x.shape[0], x.shape[1] + sp = x.shape[2:] + g = tuple(s // p for s, p in zip(sp, self.patch_size)) + v: list[int] = [B, C] + for gi, pi in zip(g, self.patch_size): + v += [gi, pi] + x = x.view(*v) + n = self.spatial_dims + gdims = list(range(2, 2 + 2 * n, 2)) + pdims = list(range(3, 3 + 2 * n, 2)) + x = x.permute(0, *gdims, *pdims, 1).contiguous() + n_patches = 1 + for gi in g: + n_patches *= gi + return x.reshape(B, n_patches, -1) + + class PatchEmbeddingBlock(nn.Module): """ A patch embedding block, based on: "Dosovitskiy et al., @@ -97,13 +123,16 @@ def __init__( in_channels=in_channels, out_channels=hidden_size, kernel_size=patch_size, stride=patch_size ) elif self.proj_type == "perceptron": - # for 3d: "b c (h 16) (w 16) (d 16) -> b (h w d) (16 16 16 c)" - dim_names = ("h", "w", "d")[:spatial_dims] - from_chars = "b c " + " ".join(f"({name} {psize})" for name, psize in zip(dim_names, patch_size)) - to_chars = f"b ({' '.join(dim_names)}) ({' '.join(str(p) for p in patch_size)} c)" - self.patch_embeddings = nn.Sequential( - Rearrange(f"{from_chars} -> {to_chars}"), nn.Linear(self.patch_dim, hidden_size) - ) + # for 3d: "b c (h p1) (w p2) (d p3) -> b (h w d) (p1 p2 p3 c)" + chars = (("h", "p1"), ("w", "p2"), ("d", "p3"))[:spatial_dims] + from_chars = "b c " + " ".join(f"({k} {v})" for k, v in chars) + to_chars = f"b ({' '.join([c[0] for c in chars])}) ({' '.join([c[1] for c in chars])} c)" + axes_len = {f"p{i + 1}": p for i, p in enumerate(patch_size)} + try: + rearrange_layer: nn.Module = Rearrange(f"{from_chars} -> {to_chars}", **axes_len) + except TypeError: + rearrange_layer = _PatchRearrange(spatial_dims, tuple(int(p) for p in patch_size)) + self.patch_embeddings = nn.Sequential(rearrange_layer, nn.Linear(self.patch_dim, hidden_size)) self.position_embeddings = nn.Parameter(torch.zeros(1, self.n_patches, hidden_size)) self.dropout = nn.Dropout(dropout_rate) From f1cffe8d1801f212add2144d1e0b2286b0cd8883 Mon Sep 17 00:00:00 2001 From: UGBOMEH OGOCHUKWU WILLIAMS Date: Sat, 18 Apr 2026 17:59:21 +0100 Subject: [PATCH 3/3] Fix ruff N806: rename uppercase vars B, C to batch, channels Signed-off-by: UGBOMEH OGOCHUKWU WILLIAMS --- monai/networks/blocks/patchembedding.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/monai/networks/blocks/patchembedding.py b/monai/networks/blocks/patchembedding.py index e033c49ab2..cf85c2836f 100644 --- a/monai/networks/blocks/patchembedding.py +++ b/monai/networks/blocks/patchembedding.py @@ -38,10 +38,10 @@ def __init__(self, spatial_dims: int, patch_size: tuple) -> None: self.patch_size = patch_size def forward(self, x: torch.Tensor) -> torch.Tensor: - B, C = x.shape[0], x.shape[1] + batch, channels = x.shape[0], x.shape[1] sp = x.shape[2:] g = tuple(s // p for s, p in zip(sp, self.patch_size)) - v: list[int] = [B, C] + v: list[int] = [batch, channels] for gi, pi in zip(g, self.patch_size): v += [gi, pi] x = x.view(*v) @@ -52,7 +52,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: n_patches = 1 for gi in g: n_patches *= gi - return x.reshape(B, n_patches, -1) + return x.reshape(batch, n_patches, -1) class PatchEmbeddingBlock(nn.Module):