Skip to content

Commit e3b1eec

Browse files
Adds encryption module to _internal
1 parent 0e64d7e commit e3b1eec

9 files changed

Lines changed: 595 additions & 0 deletions

File tree

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
from .aes_cypher import AESCipher
2+
from .client_encryption import intersect_client_encryption
3+
from .service_decryption import intersect_service_decryption
4+
5+
6+
__all__ = (
7+
'AESCipher',
8+
'intersect_client_encryption',
9+
'intersect_service_decryption',
10+
)
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
"""
2+
AESCipher class for intersect_sdk._internal.encryption.aes_cypher
3+
"""
4+
import base64
5+
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
6+
from cryptography.hazmat.primitives.padding import PKCS7
7+
from cryptography.hazmat.backends import default_backend
8+
9+
10+
class AESCipher:
11+
def __init__(self, key: bytes, initialization_vector: bytes):
12+
if len(key) * 8 != 256:
13+
raise ValueError("Invalid key size (must be 256 bit)")
14+
self._key: bytes = key
15+
# No need to encrypt initialization vectors
16+
# https://cryptography.io/en/latest/hazmat/primitives/symmetric-encryption/#cryptography.hazmat.primitives.ciphers.modes.CBC
17+
self._initial_vector: bytes = initialization_vector
18+
self._cipher: Cipher = Cipher(
19+
algorithms.AES(self._key),
20+
mode=modes.CBC(self._initial_vector),
21+
backend=default_backend(),
22+
)
23+
24+
def encrypt(self, plaintext: bytes) -> bytes:
25+
encryptor = self._cipher.encryptor()
26+
bytes_to_bits = 8
27+
padder = PKCS7(len(self._initial_vector) * bytes_to_bits).padder()
28+
padded_data = padder.update(plaintext) + padder.finalize()
29+
ciphertext = encryptor.update(padded_data) + encryptor.finalize()
30+
encodedciphertext = base64.b64encode(ciphertext)
31+
return encodedciphertext
32+
33+
def decrypt(self, ciphertext: bytes) -> bytes:
34+
decryptor = self._cipher.decryptor()
35+
decodedciphertext = base64.b64decode(ciphertext)
36+
padded_data = decryptor.update(decodedciphertext) + decryptor.finalize()
37+
bytes_to_bits = 8
38+
unpadder = PKCS7(len(self._initial_vector) * bytes_to_bits).unpadder()
39+
plaintext = unpadder.update(padded_data) + unpadder.finalize()
40+
return plaintext
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
"""
2+
Client encryption function for intersect_sdk._internal.encryption.client_encryption
3+
RSA asymmetric encryption and AES symmetric encryption
4+
"""
5+
# RSA asymmetric encryption and AES symmetric encryption
6+
import base64
7+
from cryptography.hazmat.primitives.asymmetric import padding
8+
from cryptography.hazmat.primitives import hashes, serialization
9+
from cryptography.hazmat.backends import default_backend
10+
import os
11+
12+
13+
from .aes_cypher import AESCipher
14+
from .models import IntersectEncryptedPayload, IntersectEncryptionPublicKey
15+
from ..logger import logger
16+
17+
def intersect_client_encryption(
18+
key_payload: IntersectEncryptionPublicKey,
19+
unencrypted_model: str,
20+
) -> IntersectEncryptedPayload:
21+
# Decode and deserialize the public key for asymmetric encrypt
22+
public_rsa_key = serialization.load_pem_public_key(
23+
key_payload.public_key.encode(),
24+
backend=default_backend(),
25+
)
26+
27+
# Setup AES
28+
aes_key: bytes = os.urandom(32)
29+
aes_initialization_vector: bytes = os.urandom(16)
30+
cipher = AESCipher(
31+
key=aes_key,
32+
initialization_vector=aes_initialization_vector,
33+
)
34+
35+
# Encrypt using AES
36+
logger.info("AES encrypting payload...")
37+
encrypted_data = cipher.encrypt(unencrypted_model.encode())
38+
logger.info("AES encrypted the payload!")
39+
40+
# Asymmetric RSA encrypts AES symmetric key using RSA public key
41+
logger.info("RSA encrypting AES key...")
42+
encrypted_aes_key: bytes = public_rsa_key.encrypt(
43+
aes_key,
44+
padding.OAEP(
45+
mgf=padding.MGF1(algorithm=hashes.SHA256()),
46+
algorithm=hashes.SHA256(),
47+
label=None,
48+
),
49+
)
50+
logger.info("RSA encrypted AES key!")
51+
logger.info("Done encrypting!")
52+
53+
return IntersectEncryptedPayload(
54+
key=base64.b64encode(encrypted_aes_key).decode(),
55+
initial_vector=base64.b64encode(aes_initialization_vector).decode(),
56+
data=base64.b64encode(encrypted_data).decode(),
57+
)
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from .decrypted_payload import IntersectDecryptedPayload
2+
from .encrypted_payload import IntersectEncryptedPayload
3+
from .public_key import IntersectEncryptionPublicKey
4+
5+
__all__ = (
6+
"IntersectDecryptedPayload",
7+
"IntersectEncryptedPayload",
8+
"IntersectEncryptionPublicKey",
9+
)
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from pydantic import BaseModel
2+
3+
class IntersectDecryptedPayload(BaseModel):
4+
model: BaseModel
5+
aes_key: bytes
6+
aes_initialization_vector: bytes
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from pydantic import BaseModel
2+
3+
4+
class IntersectEncryptedPayload(BaseModel):
5+
key: str
6+
initial_vector: str
7+
data: str
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
"""Definitions supporting the 'core status' functionality of the core capability."""
2+
3+
from __future__ import annotations
4+
5+
from pydantic import BaseModel, Field
6+
from typing import Annotated
7+
8+
9+
class IntersectEncryptionPublicKey(BaseModel):
10+
"""Public key information for encryption within the INTERSECT-SDK Service."""
11+
12+
public_key: Annotated[str, Field(title='Public Key PEM')]
13+
"""The PEM encoded public key for asymmetric encryption."""
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
"""
2+
Service decryption function for intersect_sdk._internal.encryption.service_decryption
3+
RSA asymmetric encryption and AES symmetric encryption
4+
"""
5+
6+
import json
7+
from typing import Dict, Type
8+
from cryptography.hazmat.primitives.asymmetric import rsa, padding
9+
from cryptography.hazmat.primitives import hashes
10+
import base64
11+
12+
from pydantic import BaseModel
13+
14+
from .aes_cypher import AESCipher
15+
from .models import IntersectEncryptedPayload, IntersectDecryptedPayload
16+
from ..logger import logger
17+
18+
19+
def intersect_service_decryption(
20+
rsa_private_key: rsa.RSAPrivateKey,
21+
encrypted_payload: IntersectEncryptedPayload,
22+
model: Type[BaseModel],
23+
) -> IntersectDecryptedPayload:
24+
# Decode base64 encoded values from payload
25+
encrypted_aes_key = base64.b64decode(encrypted_payload.key.encode())
26+
initial_vector = base64.b64decode(encrypted_payload.initial_vector.encode())
27+
encrypted_data = base64.b64decode(encrypted_payload.data.encode())
28+
29+
# Decrypt AES key using RSA
30+
logger.info("RSA decrypting AES key...")
31+
decrypted_aes_key: bytes = rsa_private_key.decrypt(
32+
encrypted_aes_key,
33+
padding.OAEP(
34+
mgf=padding.MGF1(algorithm=hashes.SHA256()),
35+
algorithm=hashes.SHA256(),
36+
label=None,
37+
),
38+
)
39+
logger.info("RSA decrypted AES key!")
40+
41+
# Encrypt using AES
42+
logger.info("AES decrypting payload...")
43+
cipher = AESCipher(
44+
key=decrypted_aes_key,
45+
initialization_vector=initial_vector,
46+
)
47+
48+
decrypted_payload: bytes = cipher.decrypt(encrypted_data)
49+
unencrypted_payload = decrypted_payload.decode()
50+
logger.info("AES decrypted payload!")
51+
52+
# Cast unencrypted payload to model and return AES key and IV for possible re-use
53+
return IntersectDecryptedPayload(
54+
model=model(**json.loads(unencrypted_payload)),
55+
aes_key=decrypted_aes_key,
56+
aes_initialization_vector=initial_vector,
57+
)

0 commit comments

Comments
 (0)