1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186
|
import json
import os
import sqlite3
from datetime import date
from typing import Any
import anysqlite
import pytest
def format_value(value: Any, col_name: str, col_type: str) -> str:
    """Format a single database cell for display in a snapshot.

    Args:
        value: Raw value read from the row.
        col_name: Column name. Numeric values in columns whose name ends
            with ``_at`` are rendered as dates.
        col_type: Declared column type, compared case-insensitively.

    Returns:
        ``"NULL"`` for ``None``. For BLOB columns: ``(JSON) ...`` when the
        bytes decode to valid JSON, ``(str) '...'`` when they decode to
        printable ASCII text, otherwise a ``(bytes) 0x...`` hex rendering
        (``repr(value)`` if the value is not ``bytes``). For numeric
        ``*_at`` columns: the ISO date, or ``str(value)`` when the number
        cannot be converted to a date. For TEXT columns: the value wrapped
        in single quotes. Anything else: ``str(value)``.
    """
    if value is None:
        return "NULL"

    declared_type = col_type.upper()

    # Handle BLOB columns
    if declared_type == "BLOB":
        if not isinstance(value, bytes):
            return repr(value)

        # Try to decode as UTF-8 string first
        try:
            decoded = value.decode("utf-8")
        except UnicodeDecodeError:
            decoded = None

        if decoded is not None:
            # Check if it looks like JSON
            if decoded.strip().startswith(("{", "[")):
                try:
                    parsed = json.loads(decoded)
                except json.JSONDecodeError:
                    # Only looked like JSON; fall through to the other renderings.
                    pass
                else:
                    return f"(JSON) {json.dumps(parsed, indent=2)}"

            # Show string if it's printable
            if all(32 <= ord(c) <= 126 or c in "\n\r\t" for c in decoded):
                return f"(str) '{decoded}'"

        # Show hex representation for binary data (truncated when long)
        hex_str = value.hex()
        if len(hex_str) > 64:
            return f"(bytes) 0x{hex_str[:60]}... ({len(value)} bytes)"
        return f"(bytes) 0x{hex_str} ({len(value)} bytes)"

    # Handle timestamps - ONLY show date, not the raw timestamp
    if col_name.endswith("_at") and isinstance(value, (int, float)):
        try:
            # NOTE: date.fromtimestamp() yields the date in the local timezone.
            return date.fromtimestamp(value).isoformat()
        except (ValueError, OverflowError, OSError):
            # OverflowError is raised for values outside the platform's
            # time_t range (e.g. 10**20 or inf); fall back to the raw number.
            return str(value)

    # Handle TEXT columns
    if declared_type == "TEXT":
        return f"'{value}'"

    # Handle other types
    return str(value)
def print_sqlite_state(conn: sqlite3.Connection) -> str:
    """
    Render all tables and their rows in a pretty format suitable for inline snapshots.

    Tables are listed in name order. Table names are quoted as SQL identifiers
    before being interpolated into the PRAGMA / SELECT statements, so names
    containing spaces, quotes or reserved words do not break the queries.

    Args:
        conn: SQLite database connection

    Returns:
        Formatted string representation of the database state
    """
    cursor = conn.cursor()

    # Get all table names
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name")
    tables = [row[0] for row in cursor.fetchall()]

    banner = "=" * 80
    output_lines = [banner, "DATABASE SNAPSHOT", banner]

    for table_name in tables:
        # Identifiers cannot be bound as parameters; quote them instead,
        # doubling any embedded double quote.
        quoted_name = '"' + table_name.replace('"', '""') + '"'

        # Get column information (row layout: cid, name, type, ...)
        cursor.execute(f"PRAGMA table_info({quoted_name})")
        columns = cursor.fetchall()
        column_names = [col[1] for col in columns]
        column_types = {col[1]: col[2] for col in columns}

        # Get all rows
        cursor.execute(f"SELECT * FROM {quoted_name}")
        rows = cursor.fetchall()

        output_lines.extend(
            [
                "",
                f"TABLE: {table_name}",
                "-" * 80,
                f"Rows: {len(rows)}",
                "",
            ]
        )

        if not rows:
            output_lines.append(" (empty)")
            continue

        # Format each row
        for idx, row in enumerate(rows, 1):
            output_lines.append(f" Row {idx}:")

            for col_name, value in zip(column_names, row):
                col_type = column_types[col_name]
                formatted_value = format_value(value, col_name, col_type)
                output_lines.append(f" {col_name:15} = {formatted_value}")

            # Blank line between rows, but not after the last one
            if idx < len(rows):
                output_lines.append("")

    output_lines.append("")
    output_lines.append(banner)

    return "\n".join(output_lines)
async def aprint_sqlite_state(conn: anysqlite.Connection) -> str:
    """
    Render all tables and their rows in a pretty format suitable for inline snapshots.

    Async counterpart for ``anysqlite`` connections. Tables are listed in name
    order. Table names are quoted as SQL identifiers before being interpolated
    into the PRAGMA / SELECT statements, so names containing spaces, quotes or
    reserved words do not break the queries.

    Args:
        conn: SQLite database connection

    Returns:
        Formatted string representation of the database state
    """
    cursor = await conn.cursor()

    # Get all table names
    await cursor.execute("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name")
    tables = [row[0] for row in await cursor.fetchall()]

    banner = "=" * 80
    output_lines = [banner, "DATABASE SNAPSHOT", banner]

    for table_name in tables:
        # Identifiers cannot be bound as parameters; quote them instead,
        # doubling any embedded double quote.
        quoted_name = '"' + table_name.replace('"', '""') + '"'

        # Get column information (row layout: cid, name, type, ...)
        await cursor.execute(f"PRAGMA table_info({quoted_name})")
        columns = await cursor.fetchall()
        column_names = [col[1] for col in columns]
        column_types = {col[1]: col[2] for col in columns}

        # Get all rows
        await cursor.execute(f"SELECT * FROM {quoted_name}")
        rows = await cursor.fetchall()

        output_lines.extend(
            [
                "",
                f"TABLE: {table_name}",
                "-" * 80,
                f"Rows: {len(rows)}",
                "",
            ]
        )

        if not rows:
            output_lines.append(" (empty)")
            continue

        # Format each row
        for idx, row in enumerate(rows, 1):
            output_lines.append(f" Row {idx}:")

            for col_name, value in zip(column_names, row):
                col_type = column_types[col_name]
                formatted_value = format_value(value, col_name, col_type)
                output_lines.append(f" {col_name:15} = {formatted_value}")

            # Blank line between rows, but not after the last one
            if idx < len(rows):
                output_lines.append("")

    output_lines.append("")
    output_lines.append(banner)

    return "\n".join(output_lines)
@pytest.fixture()
def use_temp_dir(tmpdir):
    """Run the test with ``tmpdir`` as the current working directory.

    The working directory that was active before the test is restored
    once the test has finished.
    """
    previous_cwd = os.getcwd()
    os.chdir(tmpdir)

    yield

    # Teardown: go back to where we started.
    os.chdir(previous_cwd)
|