Skip to content
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sagemaker-core/src/sagemaker/core/shapes/shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -8577,7 +8577,7 @@ class InferenceComponentComputeResourceRequirements(Base):
max_memory_required_in_mb: The maximum MB of memory to allocate to run a model that you assign to an inference component.
"""

min_memory_required_in_mb: int
min_memory_required_in_mb: Optional[int] = Unassigned()
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this required for this to work?
My only concern is that this is auto-generated and it would be overridden at some point. We would need to update the engine.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should sort out where this is coming from under the hood; I did get errors describing the endpoint without it.

number_of_cpu_cores_required: Optional[float] = Unassigned()
number_of_accelerator_devices_required: Optional[float] = Unassigned()
max_memory_required_in_mb: Optional[int] = Unassigned()
Expand Down
295 changes: 176 additions & 119 deletions sagemaker-serve/src/sagemaker/serve/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -957,6 +957,10 @@ def _fetch_and_cache_recipe_config(self):
if not self.image_uri:
self.image_uri = config.get("EcrAddress")

# Cache environment variables from recipe config
if not self.env_vars:
self.env_vars = config.get("Environment", {})
Comment on lines +961 to +962
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If it does exist, should these be appended?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Guessing yes — we should fix this.


# Infer instance type from JumpStart metadata if not provided
# This is only called for model_customization deployments
if not self.instance_type:
Expand Down Expand Up @@ -2211,21 +2215,57 @@ def _build_single_modelbuilder(
"Only SageMaker Endpoint Mode is supported for Model Customization use cases"
)
model_package = self._fetch_model_package()
# Fetch recipe config first to set image_uri, instance_type, and s3_upload_path
# Fetch recipe config first to set image_uri, instance_type, env_vars, and s3_upload_path
self._fetch_and_cache_recipe_config()
self.s3_upload_path = model_package.inference_specification.containers[
0
].model_data_source.s3_data_source.s3_uri
container_def = ContainerDefinition(
image=self.image_uri,
model_data_source={
"s3_data_source": {
"s3_uri": f"{self.s3_upload_path}/",
"s3_data_type": "S3Prefix",
"compression_type": "None",
}
},
)
peft_type = self._fetch_peft()

if peft_type == "LORA":
# For LORA: Model points at JumpStart base model, not training output
hub_document = self._fetch_hub_document_for_custom_model()
hosting_artifact_uri = hub_document.get("HostingArtifactUri")
if not hosting_artifact_uri:
raise ValueError(
"HostingArtifactUri not found in JumpStart hub metadata. "
"Cannot deploy LORA adapter without base model artifacts."
)
container_def = ContainerDefinition(
image=self.image_uri,
environment=self.env_vars,
model_data_source={
"s3_data_source": {
"s3_uri": hosting_artifact_uri,
"s3_data_type": "S3Prefix",
"compression_type": "None",
"model_access_config": {"accept_eula": True},
}
},
)
# Store adapter path for use during deploy
if isinstance(self.model, TrainingJob):
self._adapter_s3_uri = (
f"{self.model.model_artifacts.s3_model_artifacts}/checkpoints/hf/"
)
elif isinstance(self.model, ModelTrainer):
self._adapter_s3_uri = (
f"{self.model._latest_training_job.model_artifacts.s3_model_artifacts}"
"/checkpoints/hf/"
)
else:
# Non-LORA: Model points at training output
self.s3_upload_path = model_package.inference_specification.containers[
0
].model_data_source.s3_data_source.s3_uri
container_def = ContainerDefinition(
image=self.image_uri,
model_data_source={
"s3_data_source": {
"s3_uri": self.s3_upload_path.rstrip("/") + "/",
"s3_data_type": "S3Prefix",
"compression_type": "None",
}
},
)

model_name = self.model_name or f"model-{uuid.uuid4().hex[:10]}"
# Create model
self.built_model = Model.create(
Expand Down Expand Up @@ -4142,17 +4182,13 @@ def _deploy_model_customization(
"""Deploy a model customization (fine-tuned) model to an endpoint with inference components.

This method handles the special deployment flow for fine-tuned models, creating:
1. Core Model resource
2. EndpointConfig
3. Endpoint
4. InferenceComponent
1. EndpointConfig and Endpoint
2. Base model InferenceComponent (for LORA: from JumpStart base model)
3. Adapter InferenceComponent (for LORA: referencing base IC with adapter weights)

Args:
endpoint_name (str): Name of the endpoint to create or update
instance_type (str): EC2 instance type for deployment
initial_instance_count (int): Number of instances (default: 1)
wait (bool): Whether to wait for deployment to complete (default: True)
container_timeout_in_seconds (int): Container timeout in seconds (default: 300)
inference_component_name (Optional[str]): Name for the inference component
inference_config (Optional[ResourceRequirements]): Inference configuration including
resource requirements (accelerator count, memory, CPU cores)
Expand All @@ -4161,31 +4197,22 @@ def _deploy_model_customization(
Returns:
Endpoint: The deployed sagemaker.core.resources.Endpoint
"""
from sagemaker.core.resources import (
Model as CoreModel,
EndpointConfig as CoreEndpointConfig,
)
from sagemaker.core.shapes import ContainerDefinition, ProductionVariant
from sagemaker.core.shapes import (
InferenceComponentSpecification,
InferenceComponentContainerSpecification,
InferenceComponentRuntimeConfig,
InferenceComponentComputeResourceRequirements,
ModelDataSource,
S3ModelDataSource,
)
from sagemaker.core.shapes import ProductionVariant
from sagemaker.core.resources import InferenceComponent
from sagemaker.core.utils.utils import Unassigned
from sagemaker.core.resources import Tag as CoreTag

# Fetch model package
model_package = self._fetch_model_package()

# Check if endpoint exists
is_existing_endpoint = self._does_endpoint_exist(endpoint_name)

# Generate model name if not set
model_name = self.model_name or f"model-{uuid.uuid4().hex[:10]}"

if not is_existing_endpoint:
EndpointConfig.create(
endpoint_config_name=endpoint_name,
Expand All @@ -4206,114 +4233,145 @@ def _deploy_model_customization(
else:
endpoint = Endpoint.get(endpoint_name=endpoint_name)

# Set inference component name
if not inference_component_name:
if not is_existing_endpoint:
inference_component_name = f"{endpoint_name}-inference-component"
else:
inference_component_name = f"{endpoint_name}-inference-component-adapter"

# Get PEFT type and base model recipe name
peft_type = self._fetch_peft()
base_model_recipe_name = model_package.inference_specification.containers[
0
].base_model.recipe_name
base_inference_component_name = None
tag = None

# Resolve the correct model artifact URI based on deployment type
artifact_url = self._resolve_model_artifact_uri()

# Determine if this is a base model deployment
# A base model deployment uses HostingArtifactUri from JumpStart (not from model package)
is_base_model_deployment = False
if artifact_url and not peft_type:
# Check if artifact_url comes from JumpStart (not from model package)
# If model package has model_data_source, it's a full fine-tuned model
if (
hasattr(model_package.inference_specification.containers[0], "model_data_source")
and model_package.inference_specification.containers[0].model_data_source
):
is_base_model_deployment = False # Full fine-tuned model
else:
is_base_model_deployment = True # Base model from JumpStart

# Handle tagging and base component lookup
if not is_existing_endpoint and is_base_model_deployment:
# Only tag as "Base" if we're actually deploying a base model
from sagemaker.core.resources import Tag as CoreTag

tag = CoreTag(key="Base", value=base_model_recipe_name)
elif peft_type == "LORA":
# For LORA adapters, look up the existing base component
from sagemaker.core.resources import Tag as CoreTag
if peft_type == "LORA":
# LORA deployment: base IC + adapter IC

# Find or create base IC
base_ic_name = None
for component in InferenceComponent.get_all(
endpoint_name_equals=endpoint_name, status_equals="InService"
):
component_tags = CoreTag.get_all(resource_arn=component.inference_component_arn)
if any(
t.key == "Base" and t.value == base_model_recipe_name for t in component_tags
):
base_inference_component_name = component.inference_component_name
base_ic_name = component.inference_component_name
break

ic_spec = InferenceComponentSpecification(
container=InferenceComponentContainerSpecification(
image=self.image_uri, artifact_url=artifact_url, environment=self.env_vars
if not base_ic_name:
# Deploy base model IC
base_ic_name = f"{endpoint_name}-inference-component"

base_ic_spec = InferenceComponentSpecification(
model_name=self.built_model.model_name,
)
if inference_config is not None:
base_ic_spec.compute_resource_requirements = (
InferenceComponentComputeResourceRequirements(
min_memory_required_in_mb=inference_config.min_memory,
max_memory_required_in_mb=inference_config.max_memory,
number_of_cpu_cores_required=inference_config.num_cpus,
number_of_accelerator_devices_required=inference_config.num_accelerators,
)
)
else:
base_ic_spec.compute_resource_requirements = self._cached_compute_requirements

InferenceComponent.create(
inference_component_name=base_ic_name,
endpoint_name=endpoint_name,
variant_name=endpoint_name,
specification=base_ic_spec,
runtime_config=InferenceComponentRuntimeConfig(copy_count=1),
tags=[{"key": "Base", "value": base_model_recipe_name}],
)
logger.info("Created base model InferenceComponent: '%s'", base_ic_name)

# Wait for base IC to be InService before creating adapter
base_ic = InferenceComponent.get(inference_component_name=base_ic_name)
base_ic.wait_for_status("InService")

# Deploy adapter IC
adapter_ic_name = inference_component_name or f"{endpoint_name}-adapter"
adapter_s3_uri = getattr(self, "_adapter_s3_uri", None)

adapter_ic_spec = InferenceComponentSpecification(
base_inference_component_name=base_ic_name,
container=InferenceComponentContainerSpecification(
artifact_url=adapter_s3_uri,
),
)
)

if peft_type == "LORA":
ic_spec.base_inference_component_name = base_inference_component_name

# Use inference_config if provided, otherwise fall back to cached requirements
if inference_config is not None:
# Extract compute requirements from inference_config (ResourceRequirements)
ic_spec.compute_resource_requirements = InferenceComponentComputeResourceRequirements(
min_memory_required_in_mb=inference_config.min_memory,
max_memory_required_in_mb=inference_config.max_memory,
number_of_cpu_cores_required=inference_config.num_cpus,
number_of_accelerator_devices_required=inference_config.num_accelerators,
InferenceComponent.create(
inference_component_name=adapter_ic_name,
endpoint_name=endpoint_name,
specification=adapter_ic_spec,
)
logger.info("Created adapter InferenceComponent: '%s'", adapter_ic_name)

else:
# Fall back to resolved compute requirements from build()
ic_spec.compute_resource_requirements = self._cached_compute_requirements
# Non-LORA deployment: single IC
if not inference_component_name:
inference_component_name = f"{endpoint_name}-inference-component"

InferenceComponent.create(
inference_component_name=inference_component_name,
endpoint_name=endpoint_name,
variant_name=endpoint_name,
specification=ic_spec,
runtime_config=InferenceComponentRuntimeConfig(copy_count=1),
tags=[{"key": tag.key, "value": tag.value}] if tag else [],
)
artifact_url = self._resolve_model_artifact_uri()

ic_spec = InferenceComponentSpecification(
container=InferenceComponentContainerSpecification(
image=self.image_uri, artifact_url=artifact_url, environment=self.env_vars
)
)

if inference_config is not None:
ic_spec.compute_resource_requirements = (
InferenceComponentComputeResourceRequirements(
min_memory_required_in_mb=inference_config.min_memory,
max_memory_required_in_mb=inference_config.max_memory,
number_of_cpu_cores_required=inference_config.num_cpus,
number_of_accelerator_devices_required=inference_config.num_accelerators,
)
)
else:
ic_spec.compute_resource_requirements = self._cached_compute_requirements

InferenceComponent.create(
inference_component_name=inference_component_name,
endpoint_name=endpoint_name,
variant_name=endpoint_name,
specification=ic_spec,
runtime_config=InferenceComponentRuntimeConfig(copy_count=1),
)

# Create lineage tracking for new endpoints
if not is_existing_endpoint:
from sagemaker.core.resources import Action, Association, Artifact
from sagemaker.core.shapes import ActionSource, MetadataProperties
try:
from sagemaker.core.resources import Action, Association, Artifact
from sagemaker.core.shapes import ActionSource, MetadataProperties

inference_component = InferenceComponent.get(
inference_component_name=inference_component_name
)
ic_name = (
inference_component_name
if not peft_type == "LORA"
else adapter_ic_name
)
inference_component = InferenceComponent.get(
inference_component_name=ic_name
)

action = Action.create(
source=ActionSource(
source_uri=self._fetch_model_package_arn(), source_type="SageMaker"
),
action_name=f"{endpoint_name}-action",
action_type="ModelDeployment",
properties={"EndpointConfigName": endpoint_name},
metadata_properties=MetadataProperties(
generated_by=inference_component.inference_component_arn
),
)
action = Action.create(
source=ActionSource(
source_uri=self._fetch_model_package_arn(), source_type="SageMaker"
),
action_name=f"{endpoint_name}-action",
action_type="ModelDeployment",
properties={"EndpointConfigName": endpoint_name},
metadata_properties=MetadataProperties(
generated_by=inference_component.inference_component_arn
),
)

artifacts = Artifact.get_all(source_uri=model_package.model_package_arn)
for artifact in artifacts:
Association.add(source_arn=artifact.artifact_arn, destination_arn=action.action_arn)
break
artifacts = Artifact.get_all(source_uri=model_package.model_package_arn)
for artifact in artifacts:
Association.add(
source_arn=artifact.artifact_arn, destination_arn=action.action_arn
)
break
except Exception as e:
logger.warning(f"Failed to create lineage tracking: {e}")

logger.info("✅ Model customization deployment successful: Endpoint '%s'", endpoint_name)
return endpoint
Expand All @@ -4329,11 +4387,10 @@ def _fetch_peft(self) -> Optional[str]:

from sagemaker.core.utils.utils import Unassigned

if (
training_job.serverless_job_config != Unassigned()
and training_job.serverless_job_config.job_spec != Unassigned()
):
return training_job.serverless_job_config.job_spec.get("PEFT")
if training_job.serverless_job_config != Unassigned():
peft = getattr(training_job.serverless_job_config, "peft", None)
if peft and not isinstance(peft, Unassigned):
return peft
return None

def _does_endpoint_exist(self, endpoint_name: str) -> bool:
Expand Down
Loading