async_pymongo/async_pymongo/client_session.py

202 lines
6.7 KiB
Python

"""async_pymongo database session"""
# Copyright (C) 2020 - 2023 UserbotIndo Team, <https://github.com/userbotindo.git>
# Copyright (C) 2023 Mayuri-Chan, <https://github.com/Mayuri-Chan.git>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from contextlib import asynccontextmanager
from time import monotonic as monotonic_time
from typing import (
TYPE_CHECKING,
Any,
AsyncGenerator,
Callable,
Coroutine,
Mapping,
Optional,
)
from bson.timestamp import Timestamp
from pymongo.client_session import ClientSession, SessionOptions
from pymongo.read_concern import ReadConcern
from pymongo.write_concern import WriteConcern
from async_pymongo.async_helper import run_sync
from .base import AsyncBase
from .errors import OperationFailure, PyMongoError
from .typings import ReadPreferences, Results
if TYPE_CHECKING:
from .client import AsyncClient
class AsyncClientSession(AsyncBase):
    """AsyncIO :obj:`~ClientSession`

    Wraps a synchronous :class:`pymongo.client_session.ClientSession`
    (stored as ``dispatch``). Blocking session calls are offloaded with
    :func:`run_sync`; plain attribute reads are forwarded directly.

    *DEPRECATED* methods are removed in this class.
    """

    # Seconds ``with_transaction`` keeps retrying before giving up and
    # re-raising (same 120 s limit PyMongo uses for ``with_transaction``).
    _WITH_TRANSACTION_RETRY_TIME_LIMIT = 120

    _client: "AsyncClient"
    dispatch: ClientSession

    def __init__(self, client: "AsyncClient", dispatch: ClientSession) -> None:
        """Bind this session wrapper to its owning client.

        :param client: the :class:`AsyncClient` that created the session.
        :param dispatch: the underlying synchronous PyMongo session.
        """
        self._client = client
        # Propagate initialization to base
        super().__init__(dispatch)

    async def __aenter__(self) -> "AsyncClientSession":
        return self

    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        # Delegate to the synchronous session's ``__exit__`` off the event loop.
        await run_sync(self.dispatch.__exit__, exc_type, exc_val, exc_tb)

    def __enter__(self) -> None:
        raise RuntimeError("Use 'async with' not just 'with'")

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        # Never actually reached because ``__enter__`` always raises. It must
        # exist, though: the ``with`` statement looks up ``__exit__`` before it
        # calls ``__enter__``. Without it, the helpful RuntimeError above may be
        # pre-empted by a generic TypeError/AttributeError (unless a base class
        # happens to define ``__exit__``).
        return None

    async def abort_transaction(self) -> None:
        """Abort the active transaction (runs the sync call in a worker)."""
        return await run_sync(self.dispatch.abort_transaction)

    async def commit_transaction(self) -> None:
        """Commit the active transaction (runs the sync call in a worker)."""
        return await run_sync(self.dispatch.commit_transaction)

    async def end_session(self) -> None:
        """End this session (runs the sync call in a worker)."""
        return await run_sync(self.dispatch.end_session)

    @asynccontextmanager
    async def start_transaction(
        self,
        *,
        read_concern: Optional[ReadConcern] = None,
        write_concern: Optional[WriteConcern] = None,
        read_preference: Optional[ReadPreferences] = None,
        max_commit_time_ms: Optional[int] = None,
    ) -> AsyncGenerator["AsyncClientSession", None]:
        """Start a transaction; use as ``async with session.start_transaction():``.

        On normal exit the transaction is committed if it is still active.
        If the body raises anything (including cancellation), the transaction
        is aborted if still active and the exception is re-raised.

        :param read_concern: read concern for the transaction.
        :param write_concern: write concern for the transaction.
        :param read_preference: read preference for the transaction.
        :param max_commit_time_ms: max time the commit may run, in ms.
        """
        await run_sync(
            self.dispatch.start_transaction,
            read_concern=read_concern,
            write_concern=write_concern,
            read_preference=read_preference,
            max_commit_time_ms=max_commit_time_ms,
        )
        try:
            yield self
        except BaseException:
            if self.in_transaction:
                await self.abort_transaction()
            # Must re-raise. Swallowing here would hide the error from callers.
            # It would also turn ``with_transaction``'s ``raise`` into a
            # fall-through that retries forever.
            raise
        else:
            if self.in_transaction:
                await self.commit_transaction()

    async def with_transaction(
        self,
        callback: Callable[["AsyncClientSession"], Coroutine[Any, Any, Results]],
        *,
        read_concern: Optional[ReadConcern] = None,
        write_concern: Optional[WriteConcern] = None,
        read_preference: Optional[ReadPreferences] = None,
        max_commit_time_ms: Optional[int] = None,
    ) -> Results:
        """Run ``callback(self)`` in a transaction, retrying transient failures.

        The whole transaction is retried on ``TransientTransactionError``.
        The commit alone is retried on ``UnknownTransactionCommitResult``
        (unless the error is MaxTimeMSExpired, code 50). Retries stop once
        ``_WITH_TRANSACTION_RETRY_TIME_LIMIT`` seconds have elapsed, after
        which the error is re-raised.

        :param callback: coroutine function taking this session.
        :returns: the callback's return value.
        :raises: whatever the callback or commit raised when not retryable.
        """
        # 99% Of this code from motor's lib
        time_limit = self._WITH_TRANSACTION_RETRY_TIME_LIMIT

        def _within_time_limit(s: float) -> bool:
            return monotonic_time() - s < time_limit

        def _max_time_expired_error(exc: PyMongoError) -> bool:
            # Server error code 50 is MaxTimeMSExpired.
            return isinstance(exc, OperationFailure) and exc.code == 50

        start_time = monotonic_time()
        while True:
            async with self.start_transaction(
                read_concern=read_concern,
                write_concern=write_concern,
                read_preference=read_preference,
                max_commit_time_ms=max_commit_time_ms,
            ):
                try:
                    ret = await callback(self)
                except BaseException as exc:
                    if self.in_transaction:
                        await self.abort_transaction()
                    if (
                        isinstance(exc, PyMongoError)
                        and exc.has_error_label("TransientTransactionError")
                        and _within_time_limit(start_time)
                    ):
                        # Retry the entire transaction.
                        continue
                    raise

                if not self.in_transaction:
                    # Assume callback intentionally ended the transaction.
                    return ret

                while True:
                    try:
                        await self.commit_transaction()
                    except PyMongoError as exc:
                        if (
                            exc.has_error_label("UnknownTransactionCommitResult")
                            and _within_time_limit(start_time)
                            and not _max_time_expired_error(exc)
                        ):
                            # Retry the commit.
                            continue

                        if exc.has_error_label(
                            "TransientTransactionError"
                        ) and _within_time_limit(start_time):
                            # Retry the entire transaction.
                            break
                        raise

                    # Commit succeeded.
                    return ret

    def advance_cluster_time(self, cluster_time: Mapping[str, Any]) -> None:
        """Forward ``cluster_time`` to the underlying session."""
        self.dispatch.advance_cluster_time(cluster_time=cluster_time)

    def advance_operation_time(self, operation_time: Timestamp) -> None:
        """Forward ``operation_time`` to the underlying session."""
        self.dispatch.advance_operation_time(operation_time=operation_time)

    @property
    def client(self) -> "AsyncClient":
        """The :class:`AsyncClient` this session belongs to."""
        return self._client

    @property
    def cluster_time(self) -> Optional[Mapping[str, Any]]:
        """The underlying session's ``cluster_time``."""
        return self.dispatch.cluster_time

    @property
    def has_ended(self) -> bool:
        """The underlying session's ``has_ended`` flag."""
        return self.dispatch.has_ended

    @property
    def in_transaction(self) -> bool:
        """The underlying session's ``in_transaction`` flag."""
        return self.dispatch.in_transaction

    @property
    def operation_time(self) -> Optional[Timestamp]:
        """The underlying session's ``operation_time``."""
        return self.dispatch.operation_time

    @property
    def options(self) -> SessionOptions:
        """The underlying session's ``options``."""
        return self.dispatch.options

    @property
    def session_id(self) -> Mapping[str, Any]:
        """The underlying session's ``session_id``."""
        return self.dispatch.session_id