From 03eb2ac85a619e68976d30a6b4ef9560d4960f28 Mon Sep 17 00:00:00 2001
From: Tim Burke <tim.burke@gmail.com>
Date: Tue, 29 Nov 2022 09:36:52 -0800
Subject: [PATCH] Fix DB tests on py311

Change-Id: Ic2695e2e836da5607f4f5c016c660496e2821e07
---

diff --git a/swift/common/db.py b/swift/common/db.py
index 3e42d95..cb5b5cd 100644
--- a/swift/common/db.py
+++ b/swift/common/db.py
@@ -130,6 +130,7 @@
 
 class GreenDBConnection(sqlite3.Connection):
     """SQLite DB Connection handler that plays well with eventlet."""
+    __slots__ = ('timeout', 'db_file')
 
     def __init__(self, database, timeout=None, *args, **kwargs):
         if timeout is None:
@@ -143,6 +144,11 @@
             cls = GreenDBCursor
         return sqlite3.Connection.cursor(self, cls)
 
+    def execute(self, *args, **kwargs):
+        return _db_timeout(
+            self.timeout, self.db_file, lambda: sqlite3.Connection.execute(
+                self, *args, **kwargs))
+
     def commit(self):
         return _db_timeout(
             self.timeout, self.db_file,
@@ -151,6 +157,7 @@
 
 class GreenDBCursor(sqlite3.Cursor):
     """SQLite Cursor handler that plays well with eventlet."""
+    __slots__ = ('timeout', 'db_file')
 
     def __init__(self, *args, **kwargs):
         self.timeout = args[0].timeout
@@ -581,16 +588,15 @@
         conn.execute('BEGIN IMMEDIATE')
         try:
             yield True
-        except (Exception, Timeout):
-            pass
-        try:
-            conn.execute('ROLLBACK')
-            conn.isolation_level = orig_isolation_level
-            self.conn = conn
-        except (Exception, Timeout):
-            logging.exception(
-                _('Broker error trying to rollback locked connection'))
-            conn.close()
+        finally:
+            try:
+                conn.execute('ROLLBACK')
+                conn.isolation_level = orig_isolation_level
+                self.conn = conn
+            except (Exception, Timeout):
+                logging.exception(
+                    _('Broker error trying to rollback locked connection'))
+                conn.close()
 
     def _new_db_id(self):
         device_name = os.path.basename(self.get_device_path())
diff --git a/test/unit/account/test_backend.py b/test/unit/account/test_backend.py
index 9e6cb55..c5e7e3a 100644
--- a/test/unit/account/test_backend.py
+++ b/test/unit/account/test_backend.py
@@ -26,7 +26,6 @@
 from shutil import rmtree
 import sqlite3
 import itertools
-from contextlib import contextmanager
 import random
 import mock
 import base64
@@ -37,7 +36,8 @@
 from swift.account.backend import AccountBroker
 from swift.common.utils import Timestamp
 from test.unit import patch_policies, with_tempdir, make_timestamp_iter
-from swift.common.db import DatabaseConnectionError, TombstoneReclaimer
+from swift.common.db import DatabaseConnectionError, TombstoneReclaimer, \
+    GreenDBConnection
 from swift.common.request_helpers import get_reserved_name
 from swift.common.storage_policy import StoragePolicy, POLICIES
 from swift.common.utils import md5
@@ -1494,32 +1494,21 @@
         broker.put_container('c', next(self.ts).internal, 0, 0, 0,
                              POLICIES.default.idx)
 
-        real_get = broker.get
-        called = []
-
-        @contextmanager
-        def mock_get():
-            with real_get() as conn:
-
-                def mock_executescript(script):
-                    if called:
-                        raise Exception('kaboom!')
-                    called.append(script)
-
-                conn.executescript = mock_executescript
-                yield conn
-
-        broker.get = mock_get
-
         try:
-            broker._commit_puts()
+            orig_execute_script = GreenDBConnection.executescript
+            with mock.patch.object(
+                GreenDBConnection, 'executescript',
+                side_effect=[orig_execute_script, Exception('kaboom!')],
+            ) as mock_executescript:
+                broker._commit_puts()
         except Exception:
             pass
         else:
             self.fail('mock exception was not raised')
 
-        self.assertEqual(len(called), 1)
-        self.assertIn('CREATE TABLE policy_stat', called[0])
+        self.assertEqual(len(mock_executescript.mock_calls), 2)
+        self.assertIn('CREATE TABLE policy_stat',
+                      mock_executescript.mock_calls[0][1][0])
 
         # nothing was committed
         broker = AccountBroker(db_path, account='a')
diff --git a/test/unit/common/test_db.py b/test/unit/common/test_db.py
index 8a3e11a..8e762d0 100644
--- a/test/unit/common/test_db.py
+++ b/test/unit/common/test_db.py
@@ -150,18 +150,19 @@
     def test_execute_when_locked(self):
         # This test is dependent on the code under test calling execute and
         # commit as sqlite3.Cursor.execute in a subclass.
-        class InterceptCursor(sqlite3.Cursor):
+        class InterceptConnection(sqlite3.Connection):
             pass
         db_error = sqlite3.OperationalError('database is locked')
-        InterceptCursor.execute = MagicMock(side_effect=db_error)
-        with patch('sqlite3.Cursor', new=InterceptCursor):
+        InterceptConnection.execute = MagicMock(side_effect=db_error)
+        with patch('sqlite3.Connection', new=InterceptConnection):
             conn = sqlite3.connect(':memory:', check_same_thread=False,
                                    factory=GreenDBConnection, timeout=0.1)
             self.assertRaises(Timeout, conn.execute, 'select 1')
-            self.assertTrue(InterceptCursor.execute.called)
-            self.assertEqual(InterceptCursor.execute.call_args_list,
-                             list((InterceptCursor.execute.call_args,) *
-                                  InterceptCursor.execute.call_count))
+            self.assertTrue(InterceptConnection.execute.called)
+            self.assertEqual(InterceptConnection.execute.call_args_list,
+                             list((InterceptConnection.execute.call_args,) *
+                                  InterceptConnection.execute.call_count))
+            self.assertGreater(InterceptConnection.execute.call_count, 1)
 
     def text_commit_when_locked(self):
         # This test is dependent on the code under test calling commit and
