From: Arto Jantunen <viiru@debian.org>
Date: Wed, 3 Apr 2024 13:50:25 +0300
Subject: Patch tests to accept different generated import order

The change is probably due to using a different SQLAlchemy version than the
testsuite expects.
---
 src/sqlacodegen/generators.py       |  2 +-
 tests/test_cli.py                   |  4 +-
 tests/test_generator_dataclass.py   | 16 +++-----
 tests/test_generator_declarative.py | 75 +++++++++++++------------------------
 4 files changed, 35 insertions(+), 62 deletions(-)

diff --git a/src/sqlacodegen/generators.py b/src/sqlacodegen/generators.py
index 21eadb6..d6d946a 100644
--- a/src/sqlacodegen/generators.py
+++ b/src/sqlacodegen/generators.py
@@ -1446,7 +1446,7 @@ class SQLModelGenerator(DeclarativeGenerator):
             kwargs["default"] = None
             python_type_name = f"Optional[{python_type_name}]"
 
-        rendered_column = self.render_column(column, True)
+        rendered_column = self.render_column(column, True, is_table=True)
         kwargs["sa_column"] = f"{rendered_column}"
         rendered_field = render_callable("Field", kwargs=kwargs)
         return f"{column_attr.name}: {python_type_name} = {rendered_field}"
diff --git a/tests/test_cli.py b/tests/test_cli.py
index 6a176d8..7d7536e 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -72,6 +72,7 @@ def test_cli_declarative(db_path: Path, tmp_path: Path) -> None:
         output_path.read_text()
         == """\
 from sqlalchemy import Integer, Text
+
 from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
 
 class Base(DeclarativeBase):
@@ -105,6 +106,7 @@ def test_cli_dataclass(db_path: Path, tmp_path: Path) -> None:
         output_path.read_text()
         == """\
 from sqlalchemy import Integer, Text
+
 from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, mapped_column
 
 class Base(MappedAsDataclass, DeclarativeBase):
@@ -137,9 +139,9 @@ def test_cli_sqlmodels(db_path: Path, tmp_path: Path) -> None:
     assert (
         output_path.read_text()
         == """\
+from sqlalchemy import Column, Integer, Text
 from typing import Optional
 
-from sqlalchemy import Column, Integer, Text
 from sqlmodel import Field, SQLModel
 
 class Foo(SQLModel, table=True):
diff --git a/tests/test_generator_dataclass.py b/tests/test_generator_dataclass.py
index ae7eab2..9bd02df 100644
--- a/tests/test_generator_dataclass.py
+++ b/tests/test_generator_dataclass.py
@@ -32,11 +32,10 @@ def test_basic_class(generator: CodeGenerator) -> None:
     validate_code(
         generator.generate(),
         """\
-        from typing import Optional
-
         from sqlalchemy import Integer, String
         from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, \
 mapped_column
+        from typing import Optional
 
         class Base(MappedAsDataclass, DeclarativeBase):
             pass
@@ -63,11 +62,10 @@ def test_mandatory_field_last(generator: CodeGenerator) -> None:
     validate_code(
         generator.generate(),
         """\
-        from typing import Optional
-
         from sqlalchemy import Integer, String, text
         from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, \
 mapped_column
+        from typing import Optional
 
         class Base(MappedAsDataclass, DeclarativeBase):
             pass
@@ -101,11 +99,10 @@ def test_onetomany_optional(generator: CodeGenerator) -> None:
     validate_code(
         generator.generate(),
         """\
-        from typing import List, Optional
-
         from sqlalchemy import ForeignKey, Integer
         from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, \
 mapped_column, relationship
+        from typing import List, Optional
 
         class Base(MappedAsDataclass, DeclarativeBase):
             pass
@@ -152,11 +149,10 @@ def test_manytomany(generator: CodeGenerator) -> None:
     validate_code(
         generator.generate(),
         """\
-        from typing import List
-
         from sqlalchemy import Column, ForeignKey, Integer, Table
         from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, \
 mapped_column, relationship
+        from typing import List
 
         class Base(MappedAsDataclass, DeclarativeBase):
             pass
@@ -208,11 +204,10 @@ def test_named_foreign_key_constraints(generator: CodeGenerator) -> None:
     validate_code(
         generator.generate(),
         """\
-        from typing import List, Optional
-
         from sqlalchemy import ForeignKeyConstraint, Integer
         from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, \
 mapped_column, relationship
+        from typing import List, Optional
 
         class Base(MappedAsDataclass, DeclarativeBase):
             pass
@@ -256,6 +251,7 @@ def test_uuid_type_annotation(generator: CodeGenerator) -> None:
         from sqlalchemy import UUID
         from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, \
 mapped_column
+
         import uuid
 
         class Base(MappedAsDataclass, DeclarativeBase):
diff --git a/tests/test_generator_declarative.py b/tests/test_generator_declarative.py
index d9bf7b5..735cea6 100644
--- a/tests/test_generator_declarative.py
+++ b/tests/test_generator_declarative.py
@@ -47,10 +47,9 @@ def test_indexes(generator: CodeGenerator) -> None:
     validate_code(
         generator.generate(),
         """\
-from typing import Optional
-
 from sqlalchemy import Index, Integer, String
 from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
+from typing import Optional
 
 class Base(DeclarativeBase):
     pass
@@ -84,10 +83,9 @@ def test_constraints(generator: CodeGenerator) -> None:
     validate_code(
         generator.generate(),
         """\
-from typing import Optional
-
 from sqlalchemy import CheckConstraint, Integer, UniqueConstraint
 from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
+from typing import Optional
 
 class Base(DeclarativeBase):
     pass
@@ -123,10 +121,9 @@ def test_onetomany(generator: CodeGenerator) -> None:
     validate_code(
         generator.generate(),
         """\
-from typing import List, Optional
-
 from sqlalchemy import ForeignKey, Integer
 from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
+from typing import List, Optional
 
 class Base(DeclarativeBase):
     pass
@@ -166,10 +163,9 @@ def test_onetomany_selfref(generator: CodeGenerator) -> None:
     validate_code(
         generator.generate(),
         """\
-from typing import List, Optional
-
 from sqlalchemy import ForeignKey, Integer
 from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
+from typing import List, Optional
 
 class Base(DeclarativeBase):
     pass
@@ -204,10 +200,9 @@ def test_onetomany_selfref_multi(generator: CodeGenerator) -> None:
     validate_code(
         generator.generate(),
         """\
-from typing import List, Optional
-
 from sqlalchemy import ForeignKey, Integer
 from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
+from typing import List, Optional
 
 class Base(DeclarativeBase):
     pass
@@ -258,10 +253,9 @@ def test_onetomany_composite(generator: CodeGenerator) -> None:
     validate_code(
         generator.generate(),
         """\
-from typing import List, Optional
-
 from sqlalchemy import ForeignKeyConstraint, Integer
 from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
+from typing import List, Optional
 
 class Base(DeclarativeBase):
     pass
@@ -314,10 +308,9 @@ def test_onetomany_multiref(generator: CodeGenerator) -> None:
     validate_code(
         generator.generate(),
         """\
-from typing import List, Optional
-
 from sqlalchemy import ForeignKey, Integer
 from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
+from typing import List, Optional
 
 class Base(DeclarativeBase):
     pass
@@ -369,10 +362,9 @@ def test_onetoone(generator: CodeGenerator) -> None:
     validate_code(
         generator.generate(),
         """\
-from typing import Optional
-
 from sqlalchemy import ForeignKey, Integer
 from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
+from typing import Optional
 
 class Base(DeclarativeBase):
     pass
@@ -413,10 +405,9 @@ def test_onetomany_noinflect(generator: CodeGenerator) -> None:
     validate_code(
         generator.generate(),
         """\
-from typing import List, Optional
-
 from sqlalchemy import ForeignKey, Integer
 from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
+from typing import List, Optional
 
 class Base(DeclarativeBase):
     pass
@@ -461,10 +452,9 @@ def test_onetomany_conflicting_column(generator: CodeGenerator) -> None:
     validate_code(
         generator.generate(),
         """\
-from typing import List, Optional
-
 from sqlalchemy import ForeignKey, Integer, Text
 from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
+from typing import List, Optional
 
 class Base(DeclarativeBase):
     pass
@@ -506,10 +496,9 @@ def test_onetomany_conflicting_relationship(generator: CodeGenerator) -> None:
     validate_code(
         generator.generate(),
         """\
-from typing import List, Optional
-
 from sqlalchemy import ForeignKey, Integer
 from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
+from typing import List, Optional
 
 class Base(DeclarativeBase):
     pass
@@ -555,10 +544,9 @@ def test_manytoone_nobidi(generator: CodeGenerator) -> None:
     validate_code(
         generator.generate(),
         """\
-from typing import Optional
-
 from sqlalchemy import ForeignKey, Integer
 from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
+from typing import Optional
 
 class Base(DeclarativeBase):
     pass
@@ -601,10 +589,9 @@ def test_manytomany(generator: CodeGenerator) -> None:
     validate_code(
         generator.generate(),
         """\
-from typing import List
-
 from sqlalchemy import Column, ForeignKey, Integer, Table
 from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
+from typing import List
 
 class Base(DeclarativeBase):
     pass
@@ -657,10 +644,9 @@ def test_manytomany_nobidi(generator: CodeGenerator) -> None:
     validate_code(
         generator.generate(),
         """\
-from typing import List
-
 from sqlalchemy import Column, ForeignKey, Integer, Table
 from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
+from typing import List
 
 class Base(DeclarativeBase):
     pass
@@ -705,10 +691,9 @@ def test_manytomany_selfref(generator: CodeGenerator) -> None:
     validate_code(
         generator.generate(),
         """\
-from typing import List
-
 from sqlalchemy import Column, ForeignKey, Integer, Table
 from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
+from typing import List
 
 class Base(DeclarativeBase):
     pass
@@ -773,10 +758,9 @@ def test_manytomany_composite(generator: CodeGenerator) -> None:
     validate_code(
         generator.generate(),
         """\
-from typing import List
-
 from sqlalchemy import Column, ForeignKeyConstraint, Integer, Table
 from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
+from typing import List
 
 class Base(DeclarativeBase):
     pass
@@ -843,10 +827,9 @@ def test_joined_inheritance(generator: CodeGenerator) -> None:
     validate_code(
         generator.generate(),
         """\
-from typing import Optional
-
 from sqlalchemy import ForeignKey, Integer
 from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
+from typing import Optional
 
 class Base(DeclarativeBase):
     pass
@@ -980,10 +963,9 @@ def test_use_inflect_plural(
     validate_code(
         generator.generate(),
         f"""\
-from typing import Optional
-
 from sqlalchemy import ForeignKey, Integer
 from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
+from typing import Optional
 
 class Base(DeclarativeBase):
     pass
@@ -1051,10 +1033,9 @@ def test_table_args_kwargs(generator: CodeGenerator) -> None:
     validate_code(
         generator.generate(),
         """\
-from typing import Optional
-
 from sqlalchemy import Index, Integer, String
 from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
+from typing import Optional
 
 class Base(DeclarativeBase):
     pass
@@ -1091,10 +1072,9 @@ def test_foreign_key_schema(generator: CodeGenerator) -> None:
     validate_code(
         generator.generate(),
         """\
-from typing import List, Optional
-
 from sqlalchemy import ForeignKey, Integer
 from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
+from typing import List, Optional
 
 class Base(DeclarativeBase):
     pass
@@ -1136,10 +1116,9 @@ def test_invalid_attribute_names(generator: CodeGenerator) -> None:
     validate_code(
         generator.generate(),
         """\
-from typing import Optional
-
 from sqlalchemy import Integer
 from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
+from typing import Optional
 
 class Base(DeclarativeBase):
     pass
@@ -1325,10 +1304,9 @@ def test_metadata_column(generator: CodeGenerator) -> None:
     validate_code(
         generator.generate(),
         """\
-from typing import Optional
-
 from sqlalchemy import Integer, String
 from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
+from typing import Optional
 
 class Base(DeclarativeBase):
     pass
@@ -1401,11 +1379,10 @@ def test_named_constraints(generator: CodeGenerator) -> None:
     validate_code(
         generator.generate(),
         """\
-from typing import Optional
-
 from sqlalchemy import CheckConstraint, Integer, PrimaryKeyConstraint, \
 String, UniqueConstraint
 from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
+from typing import Optional
 
 class Base(DeclarativeBase):
     pass
@@ -1444,10 +1421,9 @@ def test_named_foreign_key_constraints(generator: CodeGenerator) -> None:
     validate_code(
         generator.generate(),
         """\
-from typing import List, Optional
-
 from sqlalchemy import ForeignKeyConstraint, Integer
 from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
+from typing import List, Optional
 
 class Base(DeclarativeBase):
     pass
@@ -1491,10 +1467,9 @@ def test_colname_import_conflict(generator: CodeGenerator) -> None:
     validate_code(
         generator.generate(),
         """\
-from typing import Optional
-
 from sqlalchemy import Integer, String, text
 from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
+from typing import Optional
 
 class Base(DeclarativeBase):
     pass
