1- import argparse
1+ import os
2+ import unittest
23import subprocess
34import difflib
4- import sys
5+
6+ from cms .conf import config
7+ from cms .db .drop import drop_db
8+ from cms .db .init import init_db
9+ from cms .db .session import custom_psycopg2_connection
510
611"""
712Compare the DB schema obtained from upgrading an older version's database using
1722to the first line thing is ALTER TABLE ADD CONSTRAINT, in which the constraint
1823name is on the second line. So we move the constraint name up to the first
1924line.)
20- """
2125
26+ To update the files after a new release:
2227
23- def split_schemma (schema : str ):
28+ cmsInitDB
29+ pg_dump --schema-only >schema_vX.Y.sql
30+
31+ and replace update_from_vX.Y.sql with a blank file.
32+ """
33+
34+ def split_schema (schema : str ) -> list [list [str ]]:
2435 statements : list [list [str ]] = []
2536 cur_statement : list [str ] = []
2637 for line in schema .splitlines ():
@@ -34,9 +45,10 @@ def split_schemma(schema: str):
3445 return statements
3546
3647
37- def normalize_stmt (statement : list [str ]):
48+ def normalize_stmt (statement : list [str ]) -> list [ str ] :
3849 if statement [0 ].startswith ("CREATE TABLE " ):
3950 # normalize order of columns by sorting the arguments to CREATE TABLE.
51+
4052 assert statement [- 1 ] == ");"
4153 # add missing trailing comma on the last column.
4254 assert not statement [- 2 ].endswith ("," )
@@ -56,12 +68,12 @@ def normalize_stmt(statement: list[str]):
5668 return statement
5769
5870
59- def is_create_enum (line : str ):
71+ def is_create_enum (line : str ) -> bool :
6072 return line .startswith ("CREATE TYPE " ) and line .endswith (" AS ENUM (" )
6173
6274
63- def compare_schemas (updated_schema : list [list [str ]], fresh_schema : list [list [str ]]):
64- ok = True
75+ def compare_schemas (updated_schema : list [list [str ]], fresh_schema : list [list [str ]]) -> str :
76+ errors : list [ str ] = []
6577
6678 updated_map : dict [str , list [str ]] = {}
6779 for stmt in map (normalize_stmt , updated_schema ):
@@ -75,8 +87,7 @@ def compare_schemas(updated_schema: list[list[str]], fresh_schema: list[list[str
7587
7688 for updated_stmt in updated_map .values ():
7789 if updated_stmt [0 ] not in fresh_map :
78- print ("Updated schema contains extra statement:" , * updated_stmt , sep = "\n " )
79- ok = False
90+ errors += ["Updated schema contains extra statement:" , * updated_stmt ]
8091 else :
8192 fresh_stmt = fresh_map [updated_stmt [0 ]]
8293 if is_create_enum (updated_stmt [0 ]):
@@ -86,87 +97,64 @@ def compare_schemas(updated_schema: list[list[str]], fresh_schema: list[list[str
8697 }
8798 fresh_values = {x .removesuffix ("," ).strip () for x in fresh_stmt [1 :- 1 ]}
8899 if not fresh_values .issubset (updated_values ):
89- print ( "Updated schema is missing enum value(s):" )
90- print ( "Updated:\n " + " \n " . join ( updated_stmt ))
91- print ( "Fresh:\n " + " \n " . join ( fresh_stmt ))
100+ errors += [ "Updated schema is missing enum value(s):" ]
101+ errors += [ "Updated:" ] + [ " " + x for x in updated_stmt ]
102+ errors += [ "Fresh:" ] + [ " " + x for x in fresh_stmt ]
92103 else :
93104 # Other statements must match exactly (in normalized form)
94105 if updated_stmt != fresh_stmt :
95- ok = False
96106 differ = difflib .Differ ()
97107 cmp = differ .compare (
98108 [x + "\n " for x in updated_stmt ], [x + "\n " for x in fresh_stmt ]
99109 )
100- print ( "Statement differs between updated and fresh schema:" )
101- print ( "" .join (cmp ))
110+ errors += [ "Statement differs between updated and fresh schema:" ]
111+ errors += [ "" .join (cmp ). strip ()]
102112
103113 for fresh_stmt in fresh_map .values ():
104114 if fresh_stmt [0 ] not in updated_map :
105- print ("Fresh schema contains extra statement:" , * fresh_stmt , sep = "\n " )
106- ok = False
115+ errors += ["Fresh schema contains extra statement:" , * fresh_stmt ]
107116 # if it exists, then it was already checked earlier
108117 # print('\n'.join(updated_map.keys()))
109- return ok
110-
118+ return '\n ' .join (errors )
111119
112- def get_updated_schema (user , host , name , schema_sql , updater_sql ):
113- args = [f"--username={ user } " , f"--host={ host } " , name ]
114- psql_flags = ["--quiet" , "--set=ON_ERROR_STOP=1" ]
115- subprocess .run (["dropdb" , "--if-exists" , * args ], check = True )
116- subprocess .run (["createdb" , * args ], check = True )
117- subprocess .run (
118- ["psql" , * args , * psql_flags , f"--file={ schema_sql } " ],
119- check = True ,
120- stdout = subprocess .PIPE ,
121- )
122- subprocess .run (
123- ["psql" , * args , * psql_flags , f"--file={ updater_sql } " ],
124- check = True ,
125- )
120+ def run_pg_dump () -> str :
121+ db_url = config .database .url
122+ db_url = db_url .replace ("postgresql+psycopg2://" , "postgresql://" )
126123 result = subprocess .run (
127- ["pg_dump" , "--schema-only" , * args ],
124+ ["pg_dump" , "--schema-only" , "--dbname" , db_url ],
128125 check = True ,
129126 text = True ,
130127 stdout = subprocess .PIPE ,
131128 )
132129 return result .stdout
133130
134-
135- def get_fresh_schema (user , host , name ):
136- args = [f"--username={ user } " , f"--host={ host } " , name ]
137- subprocess .run (["dropdb" , "--if-exists" , * args ], check = True )
138- subprocess .run (["createdb" , * args ], check = True )
139- subprocess .run (["cmsInitDB" ], check = True )
140- result = subprocess .run (
141- ["pg_dump" , "--schema-only" , * args ],
142- check = True ,
143- text = True ,
144- stdout = subprocess .PIPE ,
145- )
146- return result .stdout
147-
148-
149- def main ():
150- parser = argparse .ArgumentParser ()
151- parser .add_argument ("--user" , required = True )
152- parser .add_argument ("--host" , required = True )
153- parser .add_argument ("--name" , required = True )
154- parser .add_argument ("--schema_sql" , required = True )
155- parser .add_argument ("--updater_sql" , required = True )
156- args = parser .parse_args ()
157- print ("Checking schema updater..." )
158- updated_schema = split_schemma (
159- get_updated_schema (
160- args .user , args .host , args .name , args .schema_sql , args .updater_sql
161- )
162- )
163- fresh_schema = split_schemma (get_fresh_schema (args .user , args .host , args .name ))
164- if compare_schemas (updated_schema , fresh_schema ):
165- print ("All good, updater works" )
166- sys .exit (0 )
167- else :
168- sys .exit (1 )
169-
170-
171- if __name__ == "__main__" :
172- main ()
131+ def get_updated_schema (schema_file : str , updater_file : str ) -> str :
132+ drop_db ()
133+ schema_sql = open (schema_file ).read ()
134+ updater_sql = open (updater_file ).read ()
135+ # We need to do this in two separate connections, since the schema_sql sets
136+ # some connection properties which we don't want.
137+ for sql in [schema_sql , updater_sql ]:
138+ conn = custom_psycopg2_connection ()
139+ cursor = conn .cursor ()
140+ cursor .execute (sql )
141+ conn .commit ()
142+ conn .close ()
143+
144+ return run_pg_dump ()
145+
146+ def get_fresh_schema ():
147+ drop_db ()
148+ init_db ()
149+ return run_pg_dump ()
150+
151+ class TestSchemaDiff (unittest .TestCase ):
152+ def test_schema_diff (self ):
153+ dirname = os .path .dirname (__file__ )
154+ schema_file = os .path .join (dirname , "schema_v1.5.sql" )
155+ updater_file = os .path .join (dirname , "../../cmscontrib/updaters/update_from_1.5.sql" )
156+ updated_schema = split_schema (get_updated_schema (schema_file , updater_file ))
157+ fresh_schema = split_schema (get_fresh_schema ())
158+ errors = compare_schemas (updated_schema , fresh_schema )
159+ self .longMessage = False
160+ self .assertTrue (errors == "" , errors )
0 commit comments