diff --git a/sagemaker-train/src/sagemaker/train/tuner.py b/sagemaker-train/src/sagemaker/train/tuner.py
index cde1598481..8b05b9e9e1 100644
--- a/sagemaker-train/src/sagemaker/train/tuner.py
+++ b/sagemaker-train/src/sagemaker/train/tuner.py
@@ -12,7 +12,7 @@
 # language governing permissions and limitations under the License.
 """Placeholder docstring"""
 
-from __future__ import absolute_import
+from __future__ import absolute_import, annotations
 
 import logging
 from enum import Enum
@@ -442,6 +442,27 @@ def _prepare_auto_parameters(self, static_hyperparameters, hyperparameters_to_ke
 
         return new_static_hyperparameters, auto_parameters
 
+    @staticmethod
+    def _get_model_trainer_environment(
+        model_trainer: "ModelTrainer",
+    ) -> dict[str, str] | None:
+        """Extract environment variables from a ModelTrainer instance.
+
+        Returns a copy of the environment dict if it is non-empty,
+        otherwise None.
+
+        Args:
+            model_trainer (ModelTrainer): ModelTrainer instance.
+
+        Returns:
+            dict[str, str] | None: A copy of the environment variables
+            dict, or None if empty/not set.
+        """
+        env = model_trainer.environment
+        if env:
+            return dict(env)
+        return None
+
     @classmethod
     def _prepare_model_trainer_for_tuning(cls, model_trainer, inputs=None, job_name=None, **kwargs):
         """Prepare ModelTrainer before tuning by building sm_drivers and code channels.
@@ -1513,8 +1534,8 @@ def _build_training_job_definition(self, inputs):
         )
 
         # Pass through environment variables from model_trainer
-        env = getattr(model_trainer, "environment", None)
-        if env and isinstance(env, dict):
+        env = self._get_model_trainer_environment(model_trainer)
+        if env is not None:
             definition.environment = env
 
         # Pass through VPC config from model_trainer
diff --git a/sagemaker-train/tests/unit/train/test_tuner.py b/sagemaker-train/tests/unit/train/test_tuner.py
index c0255eac47..1880759809 100644
--- a/sagemaker-train/tests/unit/train/test_tuner.py
+++ b/sagemaker-train/tests/unit/train/test_tuner.py
@@ -39,12 +39,13 @@
 # ---------------------------------------------------------------------------
 
 
-def _create_mock_model_trainer(with_internal_channels=False):
+def _create_mock_model_trainer(with_internal_channels=False, environment=None):
     """Create a mock ModelTrainer with common attributes.
 
     Args:
         with_internal_channels: If True, adds internal channels (code, sm_drivers)
            to input_data_config for testing channel inclusion in tuning jobs.
+        environment: Optional dict of environment variables to set on the trainer.
     """
     trainer = MagicMock()
     trainer.sagemaker_session = MagicMock()
@@ -61,6 +62,7 @@
     trainer.stopping_condition = MagicMock()
     trainer.stopping_condition.max_runtime_in_seconds = 3600
     trainer.input_data_config = None
+    trainer.environment = environment if environment is not None else {}
 
     if with_internal_channels:
         trainer.input_data_config = [
@@ -574,3 +576,200 @@
         assert "train" in channel_names, "User 'train' channel should be included"
         assert "validation" in channel_names, "User 'validation' channel should be included"
         assert len(channel_names) == 4, "Should have exactly 4 channels"
+
+    def test_build_training_job_definition_includes_environment_variables(self):
+        """Test that _build_training_job_definition includes env vars.
+
+        This test verifies the fix for GitHub issue #5613 where tuning
+        jobs were missing environment variables set on the ModelTrainer.
+        """
+        env_vars = {"RANDOM_STATE": "42", "MY_VAR": "hello"}
+        mock_trainer = _create_mock_model_trainer(
+            environment=env_vars,
+        )
+
+        tuner = HyperparameterTuner(
+            model_trainer=mock_trainer,
+            objective_metric_name="accuracy",
+            hyperparameter_ranges=_create_single_hp_range(),
+        )
+
+        definition = tuner._build_training_job_definition(None)
+
+        # The definition should contain the environment variables
+        assert definition.environment == env_vars, (
+            f"Environment should be {env_vars}, "
+            f"got {definition.environment}"
+        )
+        # Verify defensive copy: the dict on the definition
+        # should not be the same object as the original
+        assert definition.environment is not env_vars, (
+            "Environment should be a copy, not the same object"
+        )
+
+    def test_build_training_job_definition_with_empty_environment(self):
+        """Test that empty env is not propagated to definition."""
+        mock_trainer = _create_mock_model_trainer(environment={})
+
+        tuner = HyperparameterTuner(
+            model_trainer=mock_trainer,
+            objective_metric_name="accuracy",
+            hyperparameter_ranges=_create_single_hp_range(),
+        )
+
+        definition = tuner._build_training_job_definition(None)
+        assert definition is not None
+        # Empty environment should not be set on the definition
+        env = getattr(definition, "environment", None)
+        assert env is None, (
+            "Empty environment should not be propagated, "
+            f"got {env}"
+        )
+
+    def test_build_training_job_definition_with_none_environment(self):
+        """Test that None env is not propagated to definition."""
+        mock_trainer = _create_mock_model_trainer()
+        mock_trainer.environment = None
+
+        tuner = HyperparameterTuner(
+            model_trainer=mock_trainer,
+            objective_metric_name="accuracy",
+            hyperparameter_ranges=_create_single_hp_range(),
+        )
+
+        definition = tuner._build_training_job_definition(None)
+        assert definition is not None
+        # None environment should not be set on the definition
+        env = getattr(definition, "environment", None)
+        assert env is None, (
+            "None environment should not be propagated, "
+            f"got {env}"
+        )
+
+
+class TestGetModelTrainerEnvironment:
+    """Test _get_model_trainer_environment helper method."""
+
+    def test_returns_environment_when_set(self):
+        """Test that environment is returned when set."""
+        env_vars = {"KEY1": "val1", "KEY2": "val2"}
+        mock_trainer = _create_mock_model_trainer(
+            environment=env_vars,
+        )
+
+        result = HyperparameterTuner._get_model_trainer_environment(
+            mock_trainer,
+        )
+        assert result == env_vars
+        # Verify it's a copy, not the same object
+        assert result is not env_vars, (
+            "Should return a defensive copy"
+        )
+
+    def test_returns_none_when_empty(self):
+        """Test that None is returned when environment is empty."""
+        mock_trainer = _create_mock_model_trainer(environment={})
+
+        result = HyperparameterTuner._get_model_trainer_environment(
+            mock_trainer,
+        )
+        assert result is None
+
+    def test_returns_none_when_none(self):
+        """Test that None is returned when environment is None."""
+        mock_trainer = _create_mock_model_trainer()
+        mock_trainer.environment = None
+
+        result = HyperparameterTuner._get_model_trainer_environment(
+            mock_trainer,
+        )
+        assert result is None
+
+
+class TestMultiTrainerEnvironmentPropagation:
+    """Test environment propagation for multi-trainer tuning jobs."""
+
+    def test_create_multi_trainer_with_environment(self):
+        """Test that environment is preserved on trainers in create()."""
+        env1 = {"VAR_A": "1"}
+        env2 = {"VAR_B": "2"}
+        trainer1 = _create_mock_model_trainer(environment=env1)
+        trainer2 = _create_mock_model_trainer(environment=env2)
+
+        tuner = HyperparameterTuner.create(
+            model_trainer_dict={
+                "trainer1": trainer1,
+                "trainer2": trainer2,
+            },
+            objective_metric_name_dict={
+                "trainer1": "accuracy",
+                "trainer2": "loss",
+            },
+            hyperparameter_ranges_dict={
+                "trainer1": _create_single_hp_range(),
+                "trainer2": _create_single_hp_range(),
+            },
+        )
+
+        # Verify environment is preserved on each trainer
+        assert tuner.model_trainer_dict["trainer1"].environment == env1
+        assert tuner.model_trainer_dict["trainer2"].environment == env2
+
+    def test_get_environment_for_each_trainer_in_dict(self):
+        """Test _get_model_trainer_environment for each trainer."""
+        env1 = {"VAR_A": "1"}
+        env2 = {"VAR_B": "2"}
+        trainer1 = _create_mock_model_trainer(environment=env1)
+        trainer2 = _create_mock_model_trainer(environment=env2)
+
+        tuner = HyperparameterTuner.create(
+            model_trainer_dict={
+                "trainer1": trainer1,
+                "trainer2": trainer2,
+            },
+            objective_metric_name_dict={
+                "trainer1": "accuracy",
+                "trainer2": "loss",
+            },
+            hyperparameter_ranges_dict={
+                "trainer1": _create_single_hp_range(),
+                "trainer2": _create_single_hp_range(),
+            },
+        )
+
+        for name, mt in tuner.model_trainer_dict.items():
+            env = HyperparameterTuner._get_model_trainer_environment(
+                mt,
+            )
+            if name == "trainer1":
+                assert env == env1
+            elif name == "trainer2":
+                assert env == env2
+
+    def test_multi_trainer_empty_environment(self):
+        """Test multi-trainer with empty environment."""
+        trainer1 = _create_mock_model_trainer(environment={})
+        trainer2 = _create_mock_model_trainer(environment={})
+
+        tuner = HyperparameterTuner.create(
+            model_trainer_dict={
+                "trainer1": trainer1,
+                "trainer2": trainer2,
+            },
+            objective_metric_name_dict={
+                "trainer1": "accuracy",
+                "trainer2": "loss",
+            },
+            hyperparameter_ranges_dict={
+                "trainer1": _create_single_hp_range(),
+                "trainer2": _create_single_hp_range(),
+            },
+        )
+
+        for _name, mt in tuner.model_trainer_dict.items():
+            env = HyperparameterTuner._get_model_trainer_environment(
+                mt,
+            )
+            assert env is None, (
+                "Empty environment should return None"
+            )