diff --git a/monai/networks/blocks/patchembedding.py b/monai/networks/blocks/patchembedding.py
index a4caae68be..cf85c2836f 100644
--- a/monai/networks/blocks/patchembedding.py
+++ b/monai/networks/blocks/patchembedding.py
@@ -29,6 +29,32 @@
 SUPPORTED_POS_EMBEDDING_TYPES = {"none", "learnable", "sincos", "fourier"}
 
 
+class _PatchRearrange(nn.Module):
+    """Fallback patch rearrangement using pure PyTorch, for einops compatibility."""
+
+    def __init__(self, spatial_dims: int, patch_size: tuple) -> None:
+        super().__init__()
+        self.spatial_dims = spatial_dims
+        self.patch_size = patch_size
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        batch, channels = x.shape[0], x.shape[1]
+        sp = x.shape[2:]
+        g = tuple(s // p for s, p in zip(sp, self.patch_size))
+        v: list[int] = [batch, channels]
+        for gi, pi in zip(g, self.patch_size):
+            v += [gi, pi]
+        x = x.view(*v)
+        n = self.spatial_dims
+        gdims = list(range(2, 2 + 2 * n, 2))
+        pdims = list(range(3, 3 + 2 * n, 2))
+        x = x.permute(0, *gdims, *pdims, 1).contiguous()
+        n_patches = 1
+        for gi in g:
+            n_patches *= gi
+        return x.reshape(batch, n_patches, -1)
+
+
 class PatchEmbeddingBlock(nn.Module):
     """
     A patch embedding block, based on: "Dosovitskiy et al.,
@@ -97,14 +123,16 @@ def __init__(
                 in_channels=in_channels, out_channels=hidden_size, kernel_size=patch_size, stride=patch_size
             )
         elif self.proj_type == "perceptron":
-            # for 3d: "b c (h p1) (w p2) (d p3)-> b (h w d) (p1 p2 p3 c)"
+            # for 3d: "b c (h p1) (w p2) (d p3) -> b (h w d) (p1 p2 p3 c)"
             chars = (("h", "p1"), ("w", "p2"), ("d", "p3"))[:spatial_dims]
             from_chars = "b c " + " ".join(f"({k} {v})" for k, v in chars)
             to_chars = f"b ({' '.join([c[0] for c in chars])}) ({' '.join([c[1] for c in chars])} c)"
             axes_len = {f"p{i + 1}": p for i, p in enumerate(patch_size)}
-            self.patch_embeddings = nn.Sequential(
-                Rearrange(f"{from_chars} -> {to_chars}", **axes_len), nn.Linear(self.patch_dim, hidden_size)
-            )
+            try:
+                rearrange_layer: nn.Module = Rearrange(f"{from_chars} -> {to_chars}", **axes_len)
+            except TypeError:
+                rearrange_layer = _PatchRearrange(spatial_dims, tuple(int(p) for p in patch_size))
+            self.patch_embeddings = nn.Sequential(rearrange_layer, nn.Linear(self.patch_dim, hidden_size))
         self.position_embeddings = nn.Parameter(torch.zeros(1, self.n_patches, hidden_size))
         self.dropout = nn.Dropout(dropout_rate)
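
For reference, a minimal sketch (not part of the patch) of how the pure-PyTorch fallback can be checked against einops' Rearrange for a 3D input; the batch size, channel count, patch size, and the import of the private _PatchRearrange class below are illustrative assumptions:

import torch
from einops.layers.torch import Rearrange
from monai.networks.blocks.patchembedding import _PatchRearrange

b, c, p = 2, 1, (4, 4, 4)            # illustrative batch size, channels, patch size
x = torch.randn(b, c, 16, 16, 16)    # 16^3 volume -> 4x4x4 grid of 4x4x4 patches

reference = Rearrange(
    "b c (h p1) (w p2) (d p3) -> b (h w d) (p1 p2 p3 c)", p1=p[0], p2=p[1], p3=p[2]
)
fallback = _PatchRearrange(spatial_dims=3, patch_size=p)

# Both should produce shape (2, 64, 64): 64 patches, each flattened to 4*4*4*1 values.
assert torch.equal(reference(x), fallback(x))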