-
Notifications
You must be signed in to change notification settings - Fork 1.2k
fixes for model builder #5625
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
joshuatowner
wants to merge
2
commits into
aws:master
Choose a base branch
from
joshuatowner:model-builder-fixes
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
fixes for model builder #5625
Changes from 1 commit
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -957,6 +957,10 @@ def _fetch_and_cache_recipe_config(self): | |
| if not self.image_uri: | ||
| self.image_uri = config.get("EcrAddress") | ||
|
|
||
| # Cache environment variables from recipe config | ||
| if not self.env_vars: | ||
| self.env_vars = config.get("Environment", {}) | ||
|
Comment on lines
+961
to
+962
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. If it does exist, should these be appended?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Guessing yes, we should fix. |
||
|
|
||
| # Infer instance type from JumpStart metadata if not provided | ||
| # This is only called for model_customization deployments | ||
| if not self.instance_type: | ||
|
|
@@ -2211,21 +2215,57 @@ def _build_single_modelbuilder( | |
| "Only SageMaker Endpoint Mode is supported for Model Customization use cases" | ||
| ) | ||
| model_package = self._fetch_model_package() | ||
| # Fetch recipe config first to set image_uri, instance_type, and s3_upload_path | ||
| # Fetch recipe config first to set image_uri, instance_type, env_vars, and s3_upload_path | ||
| self._fetch_and_cache_recipe_config() | ||
| self.s3_upload_path = model_package.inference_specification.containers[ | ||
| 0 | ||
| ].model_data_source.s3_data_source.s3_uri | ||
| container_def = ContainerDefinition( | ||
| image=self.image_uri, | ||
| model_data_source={ | ||
| "s3_data_source": { | ||
| "s3_uri": f"{self.s3_upload_path}/", | ||
| "s3_data_type": "S3Prefix", | ||
| "compression_type": "None", | ||
| } | ||
| }, | ||
| ) | ||
| peft_type = self._fetch_peft() | ||
|
|
||
| if peft_type == "LORA": | ||
| # For LORA: Model points at JumpStart base model, not training output | ||
| hub_document = self._fetch_hub_document_for_custom_model() | ||
| hosting_artifact_uri = hub_document.get("HostingArtifactUri") | ||
| if not hosting_artifact_uri: | ||
| raise ValueError( | ||
| "HostingArtifactUri not found in JumpStart hub metadata. " | ||
| "Cannot deploy LORA adapter without base model artifacts." | ||
| ) | ||
| container_def = ContainerDefinition( | ||
| image=self.image_uri, | ||
| environment=self.env_vars, | ||
| model_data_source={ | ||
| "s3_data_source": { | ||
| "s3_uri": hosting_artifact_uri, | ||
| "s3_data_type": "S3Prefix", | ||
| "compression_type": "None", | ||
| "model_access_config": {"accept_eula": True}, | ||
| } | ||
| }, | ||
| ) | ||
| # Store adapter path for use during deploy | ||
| if isinstance(self.model, TrainingJob): | ||
| self._adapter_s3_uri = ( | ||
| f"{self.model.model_artifacts.s3_model_artifacts}/checkpoints/hf/" | ||
| ) | ||
| elif isinstance(self.model, ModelTrainer): | ||
| self._adapter_s3_uri = ( | ||
| f"{self.model._latest_training_job.model_artifacts.s3_model_artifacts}" | ||
| "/checkpoints/hf/" | ||
| ) | ||
| else: | ||
| # Non-LORA: Model points at training output | ||
| self.s3_upload_path = model_package.inference_specification.containers[ | ||
| 0 | ||
| ].model_data_source.s3_data_source.s3_uri | ||
| container_def = ContainerDefinition( | ||
| image=self.image_uri, | ||
| model_data_source={ | ||
| "s3_data_source": { | ||
| "s3_uri": self.s3_upload_path.rstrip("/") + "/", | ||
| "s3_data_type": "S3Prefix", | ||
| "compression_type": "None", | ||
| } | ||
| }, | ||
| ) | ||
|
|
||
| model_name = self.model_name or f"model-{uuid.uuid4().hex[:10]}" | ||
| # Create model | ||
| self.built_model = Model.create( | ||
|
|
@@ -4142,17 +4182,13 @@ def _deploy_model_customization( | |
| """Deploy a model customization (fine-tuned) model to an endpoint with inference components. | ||
|
|
||
| This method handles the special deployment flow for fine-tuned models, creating: | ||
| 1. Core Model resource | ||
| 2. EndpointConfig | ||
| 3. Endpoint | ||
| 4. InferenceComponent | ||
| 1. EndpointConfig and Endpoint | ||
| 2. Base model InferenceComponent (for LORA: from JumpStart base model) | ||
| 3. Adapter InferenceComponent (for LORA: referencing base IC with adapter weights) | ||
|
|
||
| Args: | ||
| endpoint_name (str): Name of the endpoint to create or update | ||
| instance_type (str): EC2 instance type for deployment | ||
| initial_instance_count (int): Number of instances (default: 1) | ||
| wait (bool): Whether to wait for deployment to complete (default: True) | ||
| container_timeout_in_seconds (int): Container timeout in seconds (default: 300) | ||
| inference_component_name (Optional[str]): Name for the inference component | ||
| inference_config (Optional[ResourceRequirements]): Inference configuration including | ||
| resource requirements (accelerator count, memory, CPU cores) | ||
|
|
@@ -4161,31 +4197,22 @@ def _deploy_model_customization( | |
| Returns: | ||
| Endpoint: The deployed sagemaker.core.resources.Endpoint | ||
| """ | ||
| from sagemaker.core.resources import ( | ||
| Model as CoreModel, | ||
| EndpointConfig as CoreEndpointConfig, | ||
| ) | ||
| from sagemaker.core.shapes import ContainerDefinition, ProductionVariant | ||
| from sagemaker.core.shapes import ( | ||
| InferenceComponentSpecification, | ||
| InferenceComponentContainerSpecification, | ||
| InferenceComponentRuntimeConfig, | ||
| InferenceComponentComputeResourceRequirements, | ||
| ModelDataSource, | ||
| S3ModelDataSource, | ||
| ) | ||
| from sagemaker.core.shapes import ProductionVariant | ||
| from sagemaker.core.resources import InferenceComponent | ||
| from sagemaker.core.utils.utils import Unassigned | ||
| from sagemaker.core.resources import Tag as CoreTag | ||
|
|
||
| # Fetch model package | ||
| model_package = self._fetch_model_package() | ||
|
|
||
| # Check if endpoint exists | ||
| is_existing_endpoint = self._does_endpoint_exist(endpoint_name) | ||
|
|
||
| # Generate model name if not set | ||
| model_name = self.model_name or f"model-{uuid.uuid4().hex[:10]}" | ||
|
|
||
| if not is_existing_endpoint: | ||
| EndpointConfig.create( | ||
| endpoint_config_name=endpoint_name, | ||
|
|
@@ -4206,114 +4233,145 @@ def _deploy_model_customization( | |
| else: | ||
| endpoint = Endpoint.get(endpoint_name=endpoint_name) | ||
|
|
||
| # Set inference component name | ||
| if not inference_component_name: | ||
| if not is_existing_endpoint: | ||
| inference_component_name = f"{endpoint_name}-inference-component" | ||
| else: | ||
| inference_component_name = f"{endpoint_name}-inference-component-adapter" | ||
|
|
||
| # Get PEFT type and base model recipe name | ||
| peft_type = self._fetch_peft() | ||
| base_model_recipe_name = model_package.inference_specification.containers[ | ||
| 0 | ||
| ].base_model.recipe_name | ||
| base_inference_component_name = None | ||
| tag = None | ||
|
|
||
| # Resolve the correct model artifact URI based on deployment type | ||
| artifact_url = self._resolve_model_artifact_uri() | ||
|
|
||
| # Determine if this is a base model deployment | ||
| # A base model deployment uses HostingArtifactUri from JumpStart (not from model package) | ||
| is_base_model_deployment = False | ||
| if artifact_url and not peft_type: | ||
| # Check if artifact_url comes from JumpStart (not from model package) | ||
| # If model package has model_data_source, it's a full fine-tuned model | ||
| if ( | ||
| hasattr(model_package.inference_specification.containers[0], "model_data_source") | ||
| and model_package.inference_specification.containers[0].model_data_source | ||
| ): | ||
| is_base_model_deployment = False # Full fine-tuned model | ||
| else: | ||
| is_base_model_deployment = True # Base model from JumpStart | ||
|
|
||
| # Handle tagging and base component lookup | ||
| if not is_existing_endpoint and is_base_model_deployment: | ||
| # Only tag as "Base" if we're actually deploying a base model | ||
| from sagemaker.core.resources import Tag as CoreTag | ||
|
|
||
| tag = CoreTag(key="Base", value=base_model_recipe_name) | ||
| elif peft_type == "LORA": | ||
| # For LORA adapters, look up the existing base component | ||
| from sagemaker.core.resources import Tag as CoreTag | ||
| if peft_type == "LORA": | ||
| # LORA deployment: base IC + adapter IC | ||
|
|
||
| # Find or create base IC | ||
| base_ic_name = None | ||
| for component in InferenceComponent.get_all( | ||
| endpoint_name_equals=endpoint_name, status_equals="InService" | ||
| ): | ||
| component_tags = CoreTag.get_all(resource_arn=component.inference_component_arn) | ||
| if any( | ||
| t.key == "Base" and t.value == base_model_recipe_name for t in component_tags | ||
| ): | ||
| base_inference_component_name = component.inference_component_name | ||
| base_ic_name = component.inference_component_name | ||
| break | ||
|
|
||
| ic_spec = InferenceComponentSpecification( | ||
| container=InferenceComponentContainerSpecification( | ||
| image=self.image_uri, artifact_url=artifact_url, environment=self.env_vars | ||
| if not base_ic_name: | ||
| # Deploy base model IC | ||
| base_ic_name = f"{endpoint_name}-inference-component" | ||
|
|
||
| base_ic_spec = InferenceComponentSpecification( | ||
| model_name=self.built_model.model_name, | ||
| ) | ||
| if inference_config is not None: | ||
| base_ic_spec.compute_resource_requirements = ( | ||
| InferenceComponentComputeResourceRequirements( | ||
| min_memory_required_in_mb=inference_config.min_memory, | ||
| max_memory_required_in_mb=inference_config.max_memory, | ||
| number_of_cpu_cores_required=inference_config.num_cpus, | ||
| number_of_accelerator_devices_required=inference_config.num_accelerators, | ||
| ) | ||
| ) | ||
| else: | ||
| base_ic_spec.compute_resource_requirements = self._cached_compute_requirements | ||
|
|
||
| InferenceComponent.create( | ||
| inference_component_name=base_ic_name, | ||
| endpoint_name=endpoint_name, | ||
| variant_name=endpoint_name, | ||
| specification=base_ic_spec, | ||
| runtime_config=InferenceComponentRuntimeConfig(copy_count=1), | ||
| tags=[{"key": "Base", "value": base_model_recipe_name}], | ||
| ) | ||
| logger.info("Created base model InferenceComponent: '%s'", base_ic_name) | ||
|
|
||
| # Wait for base IC to be InService before creating adapter | ||
| base_ic = InferenceComponent.get(inference_component_name=base_ic_name) | ||
| base_ic.wait_for_status("InService") | ||
|
|
||
| # Deploy adapter IC | ||
| adapter_ic_name = inference_component_name or f"{endpoint_name}-adapter" | ||
| adapter_s3_uri = getattr(self, "_adapter_s3_uri", None) | ||
|
|
||
| adapter_ic_spec = InferenceComponentSpecification( | ||
| base_inference_component_name=base_ic_name, | ||
| container=InferenceComponentContainerSpecification( | ||
| artifact_url=adapter_s3_uri, | ||
| ), | ||
| ) | ||
| ) | ||
|
|
||
| if peft_type == "LORA": | ||
| ic_spec.base_inference_component_name = base_inference_component_name | ||
|
|
||
| # Use inference_config if provided, otherwise fall back to cached requirements | ||
| if inference_config is not None: | ||
| # Extract compute requirements from inference_config (ResourceRequirements) | ||
| ic_spec.compute_resource_requirements = InferenceComponentComputeResourceRequirements( | ||
| min_memory_required_in_mb=inference_config.min_memory, | ||
| max_memory_required_in_mb=inference_config.max_memory, | ||
| number_of_cpu_cores_required=inference_config.num_cpus, | ||
| number_of_accelerator_devices_required=inference_config.num_accelerators, | ||
| InferenceComponent.create( | ||
| inference_component_name=adapter_ic_name, | ||
| endpoint_name=endpoint_name, | ||
| specification=adapter_ic_spec, | ||
| ) | ||
| logger.info("Created adapter InferenceComponent: '%s'", adapter_ic_name) | ||
|
|
||
| else: | ||
| # Fall back to resolved compute requirements from build() | ||
| ic_spec.compute_resource_requirements = self._cached_compute_requirements | ||
| # Non-LORA deployment: single IC | ||
| if not inference_component_name: | ||
| inference_component_name = f"{endpoint_name}-inference-component" | ||
|
|
||
| InferenceComponent.create( | ||
| inference_component_name=inference_component_name, | ||
| endpoint_name=endpoint_name, | ||
| variant_name=endpoint_name, | ||
| specification=ic_spec, | ||
| runtime_config=InferenceComponentRuntimeConfig(copy_count=1), | ||
| tags=[{"key": tag.key, "value": tag.value}] if tag else [], | ||
| ) | ||
| artifact_url = self._resolve_model_artifact_uri() | ||
|
|
||
| ic_spec = InferenceComponentSpecification( | ||
| container=InferenceComponentContainerSpecification( | ||
| image=self.image_uri, artifact_url=artifact_url, environment=self.env_vars | ||
| ) | ||
| ) | ||
|
|
||
| if inference_config is not None: | ||
| ic_spec.compute_resource_requirements = ( | ||
| InferenceComponentComputeResourceRequirements( | ||
| min_memory_required_in_mb=inference_config.min_memory, | ||
| max_memory_required_in_mb=inference_config.max_memory, | ||
| number_of_cpu_cores_required=inference_config.num_cpus, | ||
| number_of_accelerator_devices_required=inference_config.num_accelerators, | ||
| ) | ||
| ) | ||
| else: | ||
| ic_spec.compute_resource_requirements = self._cached_compute_requirements | ||
|
|
||
| InferenceComponent.create( | ||
| inference_component_name=inference_component_name, | ||
| endpoint_name=endpoint_name, | ||
| variant_name=endpoint_name, | ||
| specification=ic_spec, | ||
| runtime_config=InferenceComponentRuntimeConfig(copy_count=1), | ||
| ) | ||
|
|
||
| # Create lineage tracking for new endpoints | ||
| if not is_existing_endpoint: | ||
| from sagemaker.core.resources import Action, Association, Artifact | ||
| from sagemaker.core.shapes import ActionSource, MetadataProperties | ||
| try: | ||
| from sagemaker.core.resources import Action, Association, Artifact | ||
| from sagemaker.core.shapes import ActionSource, MetadataProperties | ||
|
|
||
| inference_component = InferenceComponent.get( | ||
| inference_component_name=inference_component_name | ||
| ) | ||
| ic_name = ( | ||
| inference_component_name | ||
| if not peft_type == "LORA" | ||
| else adapter_ic_name | ||
| ) | ||
| inference_component = InferenceComponent.get( | ||
| inference_component_name=ic_name | ||
| ) | ||
|
|
||
| action = Action.create( | ||
| source=ActionSource( | ||
| source_uri=self._fetch_model_package_arn(), source_type="SageMaker" | ||
| ), | ||
| action_name=f"{endpoint_name}-action", | ||
| action_type="ModelDeployment", | ||
| properties={"EndpointConfigName": endpoint_name}, | ||
| metadata_properties=MetadataProperties( | ||
| generated_by=inference_component.inference_component_arn | ||
| ), | ||
| ) | ||
| action = Action.create( | ||
| source=ActionSource( | ||
| source_uri=self._fetch_model_package_arn(), source_type="SageMaker" | ||
| ), | ||
| action_name=f"{endpoint_name}-action", | ||
| action_type="ModelDeployment", | ||
| properties={"EndpointConfigName": endpoint_name}, | ||
| metadata_properties=MetadataProperties( | ||
| generated_by=inference_component.inference_component_arn | ||
| ), | ||
| ) | ||
|
|
||
| artifacts = Artifact.get_all(source_uri=model_package.model_package_arn) | ||
| for artifact in artifacts: | ||
| Association.add(source_arn=artifact.artifact_arn, destination_arn=action.action_arn) | ||
| break | ||
| artifacts = Artifact.get_all(source_uri=model_package.model_package_arn) | ||
| for artifact in artifacts: | ||
| Association.add( | ||
| source_arn=artifact.artifact_arn, destination_arn=action.action_arn | ||
| ) | ||
| break | ||
| except Exception as e: | ||
| logger.warning(f"Failed to create lineage tracking: {e}") | ||
|
|
||
| logger.info("✅ Model customization deployment successful: Endpoint '%s'", endpoint_name) | ||
| return endpoint | ||
|
|
@@ -4329,11 +4387,10 @@ def _fetch_peft(self) -> Optional[str]: | |
|
|
||
| from sagemaker.core.utils.utils import Unassigned | ||
|
|
||
| if ( | ||
| training_job.serverless_job_config != Unassigned() | ||
| and training_job.serverless_job_config.job_spec != Unassigned() | ||
| ): | ||
| return training_job.serverless_job_config.job_spec.get("PEFT") | ||
| if training_job.serverless_job_config != Unassigned(): | ||
| peft = getattr(training_job.serverless_job_config, "peft", None) | ||
| if peft and not isinstance(peft, Unassigned): | ||
| return peft | ||
| return None | ||
|
|
||
| def _does_endpoint_exist(self, endpoint_name: str) -> bool: | ||
|
|
||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this required for this to work?
My only concern is that this is auto-generated and it would be overridden at some point. We would need to update the engine.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should sort out where this is coming from under the hood; I did get errors describing the endpoint without it.