Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sqlmodel/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def partial_init() -> Generator[None, None, None]:
class SQLModelConfig(BaseConfig, total=False):
table: bool | None
registry: Any | None
model_fields_optional: str | None


def get_model_fields(model: InstanceOrType[BaseModel]) -> dict[str, "FieldInfo"]:
Expand Down
51 changes: 51 additions & 0 deletions sqlmodel/main.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import builtins
import copy
import ipaddress
import uuid
import weakref
Expand All @@ -19,6 +20,7 @@
TypeVar,
Union,
cast,
get_args,
get_origin,
overload,
)
Expand Down Expand Up @@ -56,10 +58,12 @@
PYDANTIC_MINOR_VERSION,
BaseConfig,
ModelMetaclass,
NoneType,
Representation,
SQLModelConfig,
Undefined,
UndefinedType,
_is_union_type,
finish_init,
get_annotations,
get_field_metadata,
Expand Down Expand Up @@ -91,6 +95,17 @@
OnDeleteType = Literal["CASCADE", "SET NULL", "RESTRICT"]


def _is_optional_annotation(annotation: Any) -> bool:
"""Check if a type annotation is already Optional (i.e., Union[X, None])."""
origin = get_origin(annotation)
if origin is not None and _is_union_type(origin):
args = get_args(annotation)
return NoneType in args
if annotation is NoneType:
return True
return False


def __dataclass_transform__(
*,
eq_default: bool = True,
Expand Down Expand Up @@ -560,6 +575,42 @@ def __new__(
relationship_annotations[k] = v
else:
pydantic_annotations[k] = v

# Handle model_fields_optional: make all inherited fields Optional
# with a default of None
model_fields_optional = kwargs.pop("model_fields_optional", None)
if model_fields_optional is None:
# Also check model_config in class_dict
config_dict = class_dict.get("model_config", {})
if isinstance(config_dict, dict):
model_fields_optional = config_dict.get("model_fields_optional", None)
if model_fields_optional == "all":
for base in bases:
base_fields = (
get_model_fields(base) if hasattr(base, "model_fields") else {}
)
for field_name, field_info in base_fields.items():
# Only modify fields not explicitly redefined in this class
if field_name not in pydantic_annotations:
ann = field_info.annotation
# Only wrap in Optional if not already Optional
if ann is not None and not _is_optional_annotation(ann):
pydantic_annotations[field_name] = ann | None
else:
pydantic_annotations[field_name] = ann
# Set default to None if the field was required and
# not already defined in the current class
if field_name not in dict_for_pydantic:
# Copy the FieldInfo to preserve metadata like
# min_length, ge, etc.
if hasattr(field_info, "_copy"):
new_field_info = field_info._copy()
else:
new_field_info = copy.copy(field_info)
if new_field_info.is_required():
new_field_info.default = None
dict_for_pydantic[field_name] = new_field_info

dict_used = {
**dict_for_pydantic,
"__weakref__": None,
Expand Down
241 changes: 241 additions & 0 deletions tests/test_model_fields_optional.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,241 @@
import pytest
from pydantic import ValidationError
from sqlmodel import Field, SQLModel
from sqlmodel._compat import SQLModelConfig


def test_model_fields_optional_basic(clear_sqlmodel):
"""Test that model_fields_optional='all' makes all inherited fields Optional
with a default of None."""

class HeroBase(SQLModel):
name: str
secret_name: str
age: int | None = None

class HeroUpdate(HeroBase, model_fields_optional="all"):
pass

# All fields should be optional (not required)
for field_info in HeroUpdate.model_fields.values():
assert not field_info.is_required()

# Should be able to create with no arguments
hero = HeroUpdate()
assert hero.name is None
assert hero.secret_name is None
assert hero.age is None


def test_model_fields_optional_partial_data(clear_sqlmodel):
"""Test creating an instance with only some fields set."""

class HeroBase(SQLModel):
name: str
secret_name: str
age: int | None = None

class HeroUpdate(HeroBase, model_fields_optional="all"):
pass

hero = HeroUpdate(name="Spider-Man")
assert hero.name == "Spider-Man"
assert hero.secret_name is None
assert hero.age is None


def test_model_fields_optional_exclude_unset(clear_sqlmodel):
"""Test that model_dump(exclude_unset=True) only includes explicitly set
fields."""

class HeroBase(SQLModel):
name: str
secret_name: str
age: int | None = None

class HeroUpdate(HeroBase, model_fields_optional="all"):
pass

hero = HeroUpdate(name="Spider-Man")
dumped = hero.model_dump(exclude_unset=True)
assert dumped == {"name": "Spider-Man"}


def test_model_fields_optional_override_field(clear_sqlmodel):
"""Test that explicitly redefined fields in the child class are not
overridden by model_fields_optional."""

class HeroBase(SQLModel):
name: str
secret_name: str
age: int | None = None

class HeroUpdate(HeroBase, model_fields_optional="all"):
name: str # Keep name required

# name should still be required
assert HeroUpdate.model_fields["name"].is_required()
# Other fields should be optional
assert not HeroUpdate.model_fields["secret_name"].is_required()
assert not HeroUpdate.model_fields["age"].is_required()

with pytest.raises(ValidationError):
HeroUpdate() # name is required

hero = HeroUpdate(name="Batman")
assert hero.name == "Batman"
assert hero.secret_name is None


def test_model_fields_optional_preserves_constraints(clear_sqlmodel):
"""Test that field constraints (min_length, ge, etc.) are preserved when
making fields optional."""

class HeroBase(SQLModel):
name: str = Field(min_length=1)
age: int | None = Field(default=None, ge=0)

class HeroUpdate(HeroBase, model_fields_optional="all"):
pass

# None should be valid for all fields
hero = HeroUpdate(name=None, age=None)
assert hero.name is None
assert hero.age is None

# Non-None values should still be validated
with pytest.raises(ValidationError):
HeroUpdate(name="") # min_length=1 violated

with pytest.raises(ValidationError):
HeroUpdate(age=-1) # ge=0 violated

# Valid non-None values should work
hero = HeroUpdate(name="X", age=5)
assert hero.name == "X"
assert hero.age == 5


def test_model_fields_optional_multiple_inheritance(clear_sqlmodel):
"""Test model_fields_optional with multiple levels of inheritance."""

class PersonBase(SQLModel):
first_name: str
last_name: str

class EmployeeBase(PersonBase):
employee_id: int
department: str

class EmployeeUpdate(EmployeeBase, model_fields_optional="all"):
pass

# All fields from all base classes should be optional
for field_info in EmployeeUpdate.model_fields.values():
assert not field_info.is_required()

employee = EmployeeUpdate(department="Engineering")
assert employee.department == "Engineering"
assert employee.first_name is None
assert employee.last_name is None
assert employee.employee_id is None


def test_model_fields_optional_via_model_config(clear_sqlmodel):
"""Test model_fields_optional via model_config dict."""

class HeroBase(SQLModel):
name: str
secret_name: str
age: int | None = None

class HeroUpdate(HeroBase):
model_config = SQLModelConfig(model_fields_optional="all")

# All fields should be optional
for field_info in HeroUpdate.model_fields.values():
assert not field_info.is_required()

hero = HeroUpdate()
assert hero.name is None
assert hero.secret_name is None
assert hero.age is None


def test_model_fields_optional_with_table_base(clear_sqlmodel):
"""Test that model_fields_optional works alongside table models."""

class HeroBase(SQLModel):
name: str
secret_name: str
age: int | None = None

class Hero(HeroBase, table=True):
id: int | None = Field(default=None, primary_key=True)

class HeroUpdate(HeroBase, model_fields_optional="all"):
pass

# Table model should still work normally
hero = Hero(name="Batman", secret_name="Bruce Wayne")
assert hero.name == "Batman"

# Update model should have all optional fields
update = HeroUpdate(name="Dark Knight")
assert update.name == "Dark Knight"
assert update.secret_name is None


def test_model_fields_optional_already_optional_fields(clear_sqlmodel):
"""Test that already-optional fields remain optional and keep their
defaults."""

class HeroBase(SQLModel):
name: str
nickname: str | None = "Unknown"
age: int | None = None

class HeroUpdate(HeroBase, model_fields_optional="all"):
pass

hero = HeroUpdate()
# name was required, should now be None
assert hero.name is None
# nickname had a default of "Unknown", should keep it
assert hero.nickname == "Unknown"
# age had a default of None, should stay None
assert hero.age is None


def test_model_fields_optional_model_validate(clear_sqlmodel):
"""Test that model_validate works correctly with model_fields_optional."""

class HeroBase(SQLModel):
name: str
secret_name: str
age: int | None = None

class HeroUpdate(HeroBase, model_fields_optional="all"):
pass

hero = HeroUpdate.model_validate({"name": "Spider-Man"})
assert hero.name == "Spider-Man"
assert hero.secret_name is None

hero2 = HeroUpdate.model_validate({})
assert hero2.name is None


def test_model_fields_optional_json_schema(clear_sqlmodel):
"""Test that JSON schema reflects optional fields."""

class HeroBase(SQLModel):
name: str
secret_name: str

class HeroUpdate(HeroBase, model_fields_optional="all"):
pass

schema = HeroUpdate.model_json_schema()
# No fields should be required in the schema
assert "required" not in schema or len(schema.get("required", [])) == 0
Loading