diff --git a/sdk/ml/azure-ai-ml/azure/ai/ml/_schema/_datastore/azure_storage.py b/sdk/ml/azure-ai-ml/azure/ai/ml/_schema/_datastore/azure_storage.py index 2989628ab679..3bb8e80da040 100644 --- a/sdk/ml/azure-ai-ml/azure/ai/ml/_schema/_datastore/azure_storage.py +++ b/sdk/ml/azure-ai-ml/azure/ai/ml/_schema/_datastore/azure_storage.py @@ -29,6 +29,8 @@ class AzureStorageSchema(PathAwareSchema): protocol = fields.Str() description = fields.Str() tags = fields.Dict(keys=fields.Str(), values=fields.Str()) + subscription_id = fields.Str() + resource_group = fields.Str() class AzureFileSchema(AzureStorageSchema): diff --git a/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_datastore/azure_storage.py b/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_datastore/azure_storage.py index 0fff1925177a..297428d832ca 100644 --- a/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_datastore/azure_storage.py +++ b/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_datastore/azure_storage.py @@ -52,6 +52,10 @@ class AzureFileDatastore(Datastore): :param credentials: Credentials to use for Azure ML workspace to connect to the storage. Defaults to None. :type credentials: Union[~azure.ai.ml.entities.AccountKeyConfiguration, ~azure.ai.ml.entities.SasTokenConfiguration] + :param subscription_id: Azure subscription ID of the storage account. Defaults to None. + :type subscription_id: Optional[str] + :param resource_group: Azure resource group of the storage account. Defaults to None. + :type resource_group: Optional[str] :param kwargs: A dictionary of additional configuration parameters. :type kwargs: dict """ @@ -68,6 +72,8 @@ def __init__( protocol: str = HTTPS, properties: Optional[Dict] = None, credentials: Optional[Union[AccountKeyConfiguration, SasTokenConfiguration]] = None, + subscription_id: Optional[str] = None, + resource_group: Optional[str] = None, **kwargs: Any ): kwargs[TYPE] = DatastoreType.AZURE_FILE @@ -78,6 +84,8 @@ def __init__( self.account_name = account_name self.endpoint = endpoint self.protocol = protocol + self.subscription_id = subscription_id + self.resource_group = resource_group def _to_rest_object(self) -> DatastoreData: file_ds = RestAzureFileDatastore( @@ -88,6 +96,8 @@ def _to_rest_object(self) -> DatastoreData: protocol=self.protocol, description=self.description, tags=self.tags, + subscription_id=self.subscription_id, + resource_group=self.resource_group, ) return DatastoreData(properties=file_ds) @@ -109,6 +119,8 @@ def _from_rest_object(cls, datastore_resource: DatastoreData) -> "AzureFileDatas file_share_name=properties.file_share_name, description=properties.description, tags=properties.tags, + subscription_id=properties.subscription_id, + resource_group=properties.resource_group, ) def __eq__(self, other: Any) -> bool: @@ -152,6 +164,10 @@ class AzureBlobDatastore(Datastore): :param credentials: Credentials to use for Azure ML workspace to connect to the storage. :type credentials: Union[~azure.ai.ml.entities.AccountKeyConfiguration, ~azure.ai.ml.entities.SasTokenConfiguration] + :param subscription_id: Azure subscription ID of the storage account. Defaults to None. + :type subscription_id: Optional[str] + :param resource_group: Azure resource group of the storage account. Defaults to None. + :type resource_group: Optional[str] :param kwargs: A dictionary of additional configuration parameters. :type kwargs: dict """ @@ -168,6 +184,8 @@ def __init__( protocol: str = HTTPS, properties: Optional[Dict] = None, credentials: Optional[Union[AccountKeyConfiguration, SasTokenConfiguration]] = None, + subscription_id: Optional[str] = None, + resource_group: Optional[str] = None, **kwargs: Any ): kwargs[TYPE] = DatastoreType.AZURE_BLOB @@ -179,6 +197,8 @@ def __init__( self.account_name = account_name self.endpoint = endpoint if endpoint else _get_storage_endpoint_from_metadata() self.protocol = protocol + self.subscription_id = subscription_id + self.resource_group = resource_group def _to_rest_object(self) -> DatastoreData: blob_ds = RestAzureBlobDatastore( @@ -189,6 +209,8 @@ def _to_rest_object(self) -> DatastoreData: protocol=self.protocol, tags=self.tags, description=self.description, + subscription_id=self.subscription_id, + resource_group=self.resource_group, ) return DatastoreData(properties=blob_ds) @@ -210,6 +232,8 @@ def _from_rest_object(cls, datastore_resource: DatastoreData) -> "AzureBlobDatas container_name=properties.container_name, description=properties.description, tags=properties.tags, + subscription_id=properties.subscription_id, + resource_group=properties.resource_group, ) def __eq__(self, other: Any) -> bool: @@ -256,6 +280,10 @@ class AzureDataLakeGen2Datastore(Datastore): ] :param properties: The asset property dictionary. :type properties: dict[str, str] + :param subscription_id: Azure subscription ID of the storage account. Defaults to None. + :type subscription_id: Optional[str] + :param resource_group: Azure resource group of the storage account. Defaults to None. + :type resource_group: Optional[str] :param kwargs: A dictionary of additional configuration parameters. :type kwargs: dict """ @@ -272,6 +300,8 @@ def __init__( protocol: str = HTTPS, properties: Optional[Dict] = None, credentials: Optional[Union[ServicePrincipalConfiguration, CertificateConfiguration]] = None, + subscription_id: Optional[str] = None, + resource_group: Optional[str] = None, **kwargs: Any ): kwargs[TYPE] = DatastoreType.AZURE_DATA_LAKE_GEN2 @@ -283,6 +313,8 @@ def __init__( self.filesystem = filesystem self.endpoint = endpoint self.protocol = protocol + self.subscription_id = subscription_id + self.resource_group = resource_group def _to_rest_object(self) -> DatastoreData: gen2_ds = RestAzureDataLakeGen2Datastore( @@ -293,6 +325,8 @@ def _to_rest_object(self) -> DatastoreData: protocol=self.protocol, description=self.description, tags=self.tags, + subscription_id=self.subscription_id, + resource_group=self.resource_group, ) return DatastoreData(properties=gen2_ds) @@ -316,6 +350,8 @@ def _from_rest_object(cls, datastore_resource: DatastoreData) -> "AzureDataLakeG filesystem=properties.filesystem, description=properties.description, tags=properties.tags, + subscription_id=properties.subscription_id, + resource_group=properties.resource_group, ) def __eq__(self, other: Any) -> bool: