-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Expand file tree
/
Copy pathserver.py
More file actions
120 lines (104 loc) · 4.21 KB
/
server.py
File metadata and controls
120 lines (104 loc) · 4.21 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
"""Module for Local Torch Server"""
from __future__ import absolute_import
import requests
import logging
import platform
from pathlib import Path
from sagemaker.base_predictor import PredictorBase
from sagemaker.serve.utils.optimize_utils import _is_s3_uri
from sagemaker.session import Session
from sagemaker.serve.utils.exceptions import LocalModelInvocationException
from sagemaker.s3_utils import determine_bucket_and_prefix, parse_s3_url
from sagemaker import fw_utils
from sagemaker.serve.utils.uploader import upload
from sagemaker.local.utils import get_docker_host
logger = logging.getLogger(__name__)
class LocalTorchServe:
    """Manager for a TorchServe model server running in a local Docker container.

    Starts the container via the Docker SDK, forwards invocation payloads to it
    over HTTP, and health-checks it by running a sample prediction.
    """

    def _start_torch_serve(
        self, client: object, image: str, model_path: str, secret_key: str, env_vars: dict
    ):
        """Start the TorchServe container detached, with the model dir mounted.

        Args:
            client: Docker SDK client (``docker.DockerClient``) used to run the image.
            image: Container image to run with the ``serve`` command.
            model_path: Host directory bind-mounted read-write at ``/opt/ml/model``.
            secret_key: Unused by this method; kept for signature parity with the
                other local-mode servers.  # NOTE(review): confirm no caller relies on it
            env_vars: Extra environment variables; merged over the defaults below,
                so callers may override them.
        """
        self.container = client.containers.run(
            image,
            "serve",
            detach=True,
            auto_remove=True,
            network_mode="host",
            volumes={
                Path(model_path): {
                    "bind": "/opt/ml/model",
                    "mode": "rw",
                },
            },
            environment={
                "SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code",
                "SAGEMAKER_PROGRAM": "inference.py",
                "LOCAL_PYTHON": platform.python_version(),
                **env_vars,
            },
        )

    def _invoke_torch_serve(self, request: object, content_type: str, accept: str):
        """POST *request* to the local container's ``/invocations`` endpoint.

        Args:
            request: Serialized request body to send as-is.
            content_type: Value for the ``Content-Type`` header.
            accept: Value for the ``Accept`` header.

        Returns:
            bytes: Raw response body from the container.

        Raises:
            Exception: If the request fails or the server returns an HTTP
                error status; the underlying error is chained as the cause.
        """
        try:
            response = requests.post(
                f"http://{get_docker_host()}:8080/invocations",
                data=request,
                headers={"Content-Type": content_type, "Accept": accept},
                timeout=60,  # this is what SageMaker Hosting uses as timeout
            )
            response.raise_for_status()
            return response.content
        except Exception as e:
            raise Exception("Unable to send request to the local container server") from e

    def _torchserve_deep_ping(self, predictor: PredictorBase):
        """Health-check the container by running a sample prediction.

        Args:
            predictor: Predictor whose ``predict`` is called with
                ``self.schema_builder.sample_input``.

        Returns:
            tuple: ``(healthy, response)`` — the prediction output on
            success, or ``(False, None)`` on a generic failure.

        Raises:
            LocalModelInvocationException: If the server rejects the payload
                with HTTP 422, i.e. the sample input does not match the
                deployed model's schema.
        """
        response = None
        try:
            response = predictor.predict(self.schema_builder.sample_input)
        # pylint: disable=broad-except
        except Exception as e:
            if "422 Client Error: Unprocessable Entity for url" in str(e):
                # Surface schema mismatches distinctly so callers can tell the
                # user to fix their sample input/output, chaining the cause.
                raise LocalModelInvocationException(str(e)) from e
            return (False, response)
        return (True, response)
class SageMakerTorchServe:
    """Prepares TorchServe model artifacts for SageMaker-hosted deployment."""

    def _upload_torchserve_artifacts(
        self,
        model_path: str,
        sagemaker_session: Session,
        secret_key: str,
        s3_model_data_url: str = None,
        image: str = None,
        should_upload_artifacts: bool = False,
    ):
        """Upload the model artifact to S3 and assemble serving env vars.

        Args:
            model_path: Local artifact directory, or an existing ``s3://`` URI
                that is reused without uploading.
            sagemaker_session: Session used for the upload and region lookup.
            secret_key: Unused by this method; kept for signature parity.
            s3_model_data_url: Optional explicit S3 destination URL.
            image: Image URI, used to derive the S3 key prefix.
            should_upload_artifacts: Upload only when True (and *model_path*
                is not already on S3).

        Returns:
            tuple: ``(s3_upload_path, env_vars)``; the path is ``None`` when
            nothing was uploaded.
        """
        upload_destination = None
        if _is_s3_uri(model_path):
            # Artifacts already live on S3 — reuse them as-is.
            upload_destination = model_path
        elif should_upload_artifacts:
            if s3_model_data_url:
                dest_bucket, dest_prefix = parse_s3_url(url=s3_model_data_url)
            else:
                dest_bucket = dest_prefix = None

            dest_prefix = fw_utils.model_code_key_prefix(dest_prefix, None, image)
            dest_bucket, dest_prefix = determine_bucket_and_prefix(
                bucket=dest_bucket, key_prefix=dest_prefix, sagemaker_session=sagemaker_session
            )
            logger.debug(
                "Uploading the model resources to bucket=%s, key_prefix=%s.",
                dest_bucket,
                dest_prefix,
            )
            upload_destination = upload(sagemaker_session, model_path, dest_bucket, dest_prefix)
            logger.debug("Model resources uploaded to: %s", upload_destination)

        env_vars = {
            "SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code",
            "SAGEMAKER_PROGRAM": "inference.py",
            "SAGEMAKER_REGION": sagemaker_session.boto_region_name,
            "SAGEMAKER_CONTAINER_LOG_LEVEL": "10",
            "LOCAL_PYTHON": platform.python_version(),
        }
        return upload_destination, env_vars