-
Notifications
You must be signed in to change notification settings - Fork 1.2k
fix: ModelTrainer and HyperparameterTuner missing environment variables (5613) #5695
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -12,7 +12,7 @@ | |
| # language governing permissions and limitations under the License. | ||
| """Placeholder docstring""" | ||
|
|
||
| from __future__ import absolute_import | ||
| from __future__ import absolute_import, annotations | ||
|
|
||
| import logging | ||
| from enum import Enum | ||
|
|
@@ -442,6 +442,27 @@ def _prepare_auto_parameters(self, static_hyperparameters, hyperparameters_to_ke | |
|
|
||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Minor design question: This static method is essentially a 4-line null-safe copy. It's only called in one place ( |
||
| return new_static_hyperparameters, auto_parameters | ||
aviruthen marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| @staticmethod | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The type annotation uses a forward reference string. Consider a `TYPE_CHECKING` guard so static analyzers can resolve the type:

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        from sagemaker.train.model_trainer import ModelTrainer

This avoids a runtime import cycle while enabling static analysis tools to resolve the type. |
||
| def _get_model_trainer_environment( | ||
| model_trainer: "ModelTrainer", | ||
| ) -> dict[str, str] | None: | ||
| """Extract environment variables from a ModelTrainer instance. | ||
|
|
||
| Returns a copy of the environment dict if it is non-empty, | ||
| otherwise None. | ||
|
|
||
| Args: | ||
| model_trainer (ModelTrainer): ModelTrainer instance. | ||
|
|
||
| Returns: | ||
| dict[str, str] | None: A copy of the environment variables | ||
| dict, or None if empty/not set. | ||
| """ | ||
aviruthen marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| env = model_trainer.environment | ||
aviruthen marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| if env: | ||
| return dict(env) | ||
| return None | ||
|
|
||
| @classmethod | ||
| def _prepare_model_trainer_for_tuning(cls, model_trainer, inputs=None, job_name=None, **kwargs): | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The PR description states: "Both the single-trainer path ( |
||
| """Prepare ModelTrainer before tuning by building sm_drivers and code channels. | ||
aviruthen marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
@@ -1513,8 +1534,8 @@ def _build_training_job_definition(self, inputs): | |
| ) | ||
|
|
||
| # Pass through environment variables from model_trainer | ||
| env = getattr(model_trainer, "environment", None) | ||
| if env and isinstance(env, dict): | ||
| env = self._get_model_trainer_environment(model_trainer) | ||
aviruthen marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| if env is not None: | ||
| definition.environment = env | ||
|
|
||
| # Pass through VPC config from model_trainer | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -39,12 +39,13 @@ | |
| # --------------------------------------------------------------------------- | ||
|
|
||
|
|
||
| def _create_mock_model_trainer(with_internal_channels=False): | ||
| def _create_mock_model_trainer(with_internal_channels=False, environment=None): | ||
| """Create a mock ModelTrainer with common attributes. | ||
|
|
||
| Args: | ||
| with_internal_channels: If True, adds internal channels (code, sm_drivers) | ||
| to input_data_config for testing channel inclusion in tuning jobs. | ||
| environment: Optional dict of environment variables to set on the trainer. | ||
| """ | ||
| trainer = MagicMock() | ||
| trainer.sagemaker_session = MagicMock() | ||
|
|
@@ -61,6 +62,7 @@ def _create_mock_model_trainer(with_internal_channels=False): | |
| trainer.stopping_condition = MagicMock() | ||
aviruthen marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| trainer.stopping_condition.max_runtime_in_seconds = 3600 | ||
| trainer.input_data_config = None | ||
| trainer.environment = environment if environment is not None else {} | ||
|
|
||
| if with_internal_channels: | ||
| trainer.input_data_config = [ | ||
|
|
@@ -574,3 +576,200 @@ def test_build_training_job_definition_includes_internal_channels(self): | |
| assert "train" in channel_names, "User 'train' channel should be included" | ||
aviruthen marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| assert "validation" in channel_names, "User 'validation' channel should be included" | ||
| assert len(channel_names) == 4, "Should have exactly 4 channels" | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Good test — this directly validates the fix for issue #5613 by calling |
||
|
|
||
| def test_build_training_job_definition_includes_environment_variables(self): | ||
| """Test that _build_training_job_definition includes env vars. | ||
|
|
||
| This test verifies the fix for GitHub issue #5613 where tuning | ||
| jobs were missing environment variables set on the ModelTrainer. | ||
| """ | ||
| env_vars = {"RANDOM_STATE": "42", "MY_VAR": "hello"} | ||
| mock_trainer = _create_mock_model_trainer( | ||
| environment=env_vars, | ||
| ) | ||
|
|
||
| tuner = HyperparameterTuner( | ||
| model_trainer=mock_trainer, | ||
| objective_metric_name="accuracy", | ||
| hyperparameter_ranges=_create_single_hp_range(), | ||
| ) | ||
|
|
||
aviruthen marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| definition = tuner._build_training_job_definition(None) | ||
|
|
||
aviruthen marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| # The definition should contain the environment variables | ||
aviruthen marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| assert definition.environment == env_vars, ( | ||
| f"Environment should be {env_vars}, " | ||
| f"got {definition.environment}" | ||
| ) | ||
| # Verify defensive copy: the dict on the definition | ||
| # should not be the same object as the original | ||
| assert definition.environment is not env_vars, ( | ||
| "Environment should be a copy, not the same object" | ||
| ) | ||
|
|
||
| def test_build_training_job_definition_with_empty_environment(self): | ||
| """Test that empty env is not propagated to definition.""" | ||
| mock_trainer = _create_mock_model_trainer(environment={}) | ||
|
|
||
| tuner = HyperparameterTuner( | ||
| model_trainer=mock_trainer, | ||
| objective_metric_name="accuracy", | ||
| hyperparameter_ranges=_create_single_hp_range(), | ||
| ) | ||
|
|
||
| definition = tuner._build_training_job_definition(None) | ||
| assert definition is not None | ||
| # Empty environment should not be set on the definition | ||
| env = getattr(definition, "environment", None) | ||
| assert env is None, ( | ||
| "Empty environment should not be propagated, " | ||
| f"got {env}" | ||
| ) | ||
|
|
||
| def test_build_training_job_definition_with_none_environment(self): | ||
| """Test that None env is not propagated to definition.""" | ||
| mock_trainer = _create_mock_model_trainer() | ||
| mock_trainer.environment = None | ||
|
|
||
| tuner = HyperparameterTuner( | ||
| model_trainer=mock_trainer, | ||
| objective_metric_name="accuracy", | ||
| hyperparameter_ranges=_create_single_hp_range(), | ||
| ) | ||
|
|
||
aviruthen marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| definition = tuner._build_training_job_definition(None) | ||
| assert definition is not None | ||
| # None environment should not be set on the definition | ||
| env = getattr(definition, "environment", None) | ||
| assert env is None, ( | ||
| "None environment should not be propagated, " | ||
| f"got {env}" | ||
| ) | ||
|
|
||
|
|
||
| class TestGetModelTrainerEnvironment: | ||
| """Test _get_model_trainer_environment helper method.""" | ||
|
|
||
| def test_returns_environment_when_set(self): | ||
| """Test that environment is returned when set.""" | ||
| env_vars = {"KEY1": "val1", "KEY2": "val2"} | ||
| mock_trainer = _create_mock_model_trainer( | ||
| environment=env_vars, | ||
| ) | ||
|
|
||
| result = HyperparameterTuner._get_model_trainer_environment( | ||
| mock_trainer, | ||
| ) | ||
| assert result == env_vars | ||
| # Verify it's a copy, not the same object | ||
| assert result is not env_vars, ( | ||
| "Should return a defensive copy" | ||
| ) | ||
|
|
||
| def test_returns_none_when_empty(self): | ||
| """Test that None is returned when environment is empty.""" | ||
| mock_trainer = _create_mock_model_trainer(environment={}) | ||
|
|
||
| result = HyperparameterTuner._get_model_trainer_environment( | ||
| mock_trainer, | ||
| ) | ||
| assert result is None | ||
|
|
||
| def test_returns_none_when_none(self): | ||
| """Test that None is returned when environment is None.""" | ||
| mock_trainer = _create_mock_model_trainer() | ||
| mock_trainer.environment = None | ||
|
|
||
| result = HyperparameterTuner._get_model_trainer_environment( | ||
| mock_trainer, | ||
| ) | ||
| assert result is None | ||
|
|
||
|
|
||
| class TestMultiTrainerEnvironmentPropagation: | ||
| """Test environment propagation for multi-trainer tuning jobs.""" | ||
|
|
||
| def test_create_multi_trainer_with_environment(self): | ||
| """Test that environment is preserved on trainers in create().""" | ||
| env1 = {"VAR_A": "1"} | ||
| env2 = {"VAR_B": "2"} | ||
| trainer1 = _create_mock_model_trainer(environment=env1) | ||
| trainer2 = _create_mock_model_trainer(environment=env2) | ||
|
|
||
| tuner = HyperparameterTuner.create( | ||
| model_trainer_dict={ | ||
| "trainer1": trainer1, | ||
| "trainer2": trainer2, | ||
| }, | ||
| objective_metric_name_dict={ | ||
| "trainer1": "accuracy", | ||
| "trainer2": "loss", | ||
| }, | ||
| hyperparameter_ranges_dict={ | ||
| "trainer1": _create_single_hp_range(), | ||
| "trainer2": _create_single_hp_range(), | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. These multi-trainer tests ( |
||
| }, | ||
| ) | ||
|
|
||
| # Verify environment is preserved on each trainer | ||
| assert tuner.model_trainer_dict["trainer1"].environment == env1 | ||
| assert tuner.model_trainer_dict["trainer2"].environment == env2 | ||
|
|
||
| def test_get_environment_for_each_trainer_in_dict(self): | ||
| """Test _get_model_trainer_environment for each trainer.""" | ||
| env1 = {"VAR_A": "1"} | ||
| env2 = {"VAR_B": "2"} | ||
| trainer1 = _create_mock_model_trainer(environment=env1) | ||
| trainer2 = _create_mock_model_trainer(environment=env2) | ||
|
|
||
| tuner = HyperparameterTuner.create( | ||
| model_trainer_dict={ | ||
| "trainer1": trainer1, | ||
| "trainer2": trainer2, | ||
| }, | ||
aviruthen marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| objective_metric_name_dict={ | ||
| "trainer1": "accuracy", | ||
| "trainer2": "loss", | ||
| }, | ||
| hyperparameter_ranges_dict={ | ||
| "trainer1": _create_single_hp_range(), | ||
| "trainer2": _create_single_hp_range(), | ||
| }, | ||
| ) | ||
|
|
||
| for name, mt in tuner.model_trainer_dict.items(): | ||
| env = HyperparameterTuner._get_model_trainer_environment( | ||
| mt, | ||
| ) | ||
| if name == "trainer1": | ||
| assert env == env1 | ||
| elif name == "trainer2": | ||
| assert env == env2 | ||
|
|
||
| def test_multi_trainer_empty_environment(self): | ||
| """Test multi-trainer with empty environment.""" | ||
| trainer1 = _create_mock_model_trainer(environment={}) | ||
| trainer2 = _create_mock_model_trainer(environment={}) | ||
|
|
||
| tuner = HyperparameterTuner.create( | ||
| model_trainer_dict={ | ||
| "trainer1": trainer1, | ||
| "trainer2": trainer2, | ||
| }, | ||
| objective_metric_name_dict={ | ||
| "trainer1": "accuracy", | ||
| "trainer2": "loss", | ||
| }, | ||
| hyperparameter_ranges_dict={ | ||
| "trainer1": _create_single_hp_range(), | ||
| "trainer2": _create_single_hp_range(), | ||
| }, | ||
| ) | ||
|
|
||
| for _name, mt in tuner.model_trainer_dict.items(): | ||
| env = HyperparameterTuner._get_model_trainer_environment( | ||
| mt, | ||
| ) | ||
| assert env is None, ( | ||
| "Empty environment should return None" | ||
| ) | ||
Uh oh!
There was an error while loading. Please reload this page.