218 lines
6.8 KiB
Python
218 lines
6.8 KiB
Python
"""async_pymongo database cursor core"""
|
|
# Copyright (C) 2020 - 2023 UserbotIndo Team, <https://github.com/userbotindo.git>
|
|
# Copyright (C) 2023 Mayuri-Chan, <https://github.com/Mayuri-Chan.git>
|
|
#
|
|
# This program is free software: you can redistribute it and/or modify
|
|
# it under the terms of the GNU General Public License as published by
|
|
# the Free Software Foundation, either version 3 of the License, or
|
|
# (at your option) any later version.
|
|
#
|
|
# This program is distributed in the hope that it will be useful,
|
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
# GNU General Public License for more details.
|
|
#
|
|
# You should have received a copy of the GNU General Public License
|
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
|
|
import asyncio
|
|
import inspect
|
|
from functools import partial
|
|
from typing import (
|
|
TYPE_CHECKING,
|
|
Any,
|
|
Coroutine,
|
|
Deque,
|
|
Generic,
|
|
List,
|
|
Mapping,
|
|
Optional,
|
|
Tuple,
|
|
Union,
|
|
)
|
|
|
|
from pymongo.client_session import ClientSession
|
|
from pymongo.collection import Collection
|
|
from pymongo.cursor import _QUERY_OPTIONS, Cursor, RawBatchCursor
|
|
from pymongo.typings import _Address, _DocumentType
|
|
|
|
from async_pymongo.async_helper import run_sync
|
|
|
|
from .base import AsyncBase
|
|
from .errors import InvalidOperation
|
|
|
|
if TYPE_CHECKING:
|
|
from .collection import AsyncCollection
|
|
from .command_cursor import CommandCursor, _LatentCursor
|
|
|
|
|
|
class AsyncCursorBase(AsyncBase, Generic[_DocumentType]):
    """Base class for asyncio wrappers around PyMongo cursors.

    Methods that PyMongo marks as *deprecated* are intentionally not
    provided. ``each()`` is omitted as well: iterate the cursor directly
    with ``async for`` or collect documents with :meth:`to_list`.

    Subclasses must implement :meth:`_query_flags`, :meth:`_data` and
    :meth:`_killed`.
    """

    collection: Optional[Union["AsyncCollection[_DocumentType]", Collection[_DocumentType]]]
    dispatch: Union[
        "_LatentCursor[_DocumentType]",
        "CommandCursor[_DocumentType]",
        Cursor[_DocumentType],
        RawBatchCursor[_DocumentType],
    ]
    loop: asyncio.AbstractEventLoop

    def __init__(
        self,
        cursor: Union[
            "_LatentCursor[_DocumentType]",
            "CommandCursor[_DocumentType]",
            Cursor[_DocumentType],
            RawBatchCursor[_DocumentType],
        ],
        collection: "Optional[AsyncCollection[_DocumentType]]" = None,
    ) -> None:
        """Wrap the synchronous *cursor*.

        :param cursor: the underlying synchronous cursor; ``AsyncBase``
            presumably stores it as ``self.dispatch`` (confirm in ``.base``).
        :param collection: the owning collection. When omitted,
            ``cursor.collection`` is used instead.
        """
        super().__init__(cursor)

        # Compare with None explicitly: PyMongo collections raise on bool(),
        # and a wrapper around one may do the same.
        if collection is not None:
            self.collection = collection
        else:
            self.collection = cursor.collection
        self.started = False
        self.closed = False

        try:
            self.loop = asyncio.get_running_loop()
        except RuntimeError:
            # Built outside a coroutine: keep the historical behaviour.
            self.loop = asyncio.get_event_loop()

    async def __aenter__(self) -> "AsyncCursorBase":
        return self

    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        # Close the underlying cursor on leaving an ``async with`` block.
        if self.dispatch:
            await self.close()

    def __aiter__(self) -> "AsyncCursorBase":
        return self

    async def __anext__(self) -> Mapping[str, Any]:
        return await self.next()

    def _buffer_size(self) -> int:
        """Return how many documents are already buffered locally."""
        return len(self._data())

    def _query_flags(self) -> int:
        """Return the query option bit flags. Subclasses must override."""
        raise NotImplementedError

    def _data(self) -> Deque[Any]:
        """Return the local document buffer. Subclasses must override."""
        raise NotImplementedError

    def _killed(self) -> bool:
        """Return whether the cursor was killed. Subclasses must override."""
        raise NotImplementedError

    def _get_more(self) -> Coroutine[Any, Any, int]:
        """Return a coroutine that refills the buffer and yields its size.

        :raises InvalidOperation: if the cursor is exhausted or killed.
        """
        if not self.alive:
            raise InvalidOperation(
                "Can't call get_more() on a AsyncCursor that has been exhausted or killed."
            )

        self.started = True
        return self._refresh()

    def _schedule_get_more(self) -> "asyncio.Future[int]":
        """Start fetching a batch and return a future for its result.

        ``_get_more`` returns a coroutine here, but overrides (e.g. in the
        command cursor) may return something that is not a coroutine. Such
        a value is assumed to already be a future and is returned as-is.
        """
        pending: Any = self._get_more()
        if inspect.iscoroutine(pending):
            return self.loop.create_task(pending)
        return pending

    def _fetch_into_list(
        self,
        length: Optional[int],
        the_list: List[Mapping[str, Any]],
        future: "asyncio.Future[List[Mapping[str, Any]]]",
    ) -> None:
        """Fetch the next batch, then continue filling *the_list* via :meth:`_to_list`."""
        get_more_future = self._schedule_get_more()
        get_more_future.add_done_callback(
            partial(self.loop.call_soon_threadsafe, self._to_list, length, the_list, future)
        )

    def _to_list(
        self,
        length: Optional[int],
        the_list: List[Mapping[str, Any]],
        future: asyncio.Future[List[Mapping[str, Any]]],
        get_more_future: asyncio.Future[int],
    ) -> None:
        """Move one fetched batch into *the_list*, then finish or fetch again.

        :param length: maximum number of documents to collect, or ``None``
            for all of them.
        :param the_list: accumulator shared across successive batches.
        :param future: the future handed back to the ``to_list()`` caller.
        :param get_more_future: the finished fetch; its result is the
            number of buffered documents.
        """
        try:
            result = get_more_future.result()
            # The caller may have cancelled (or otherwise settled) their future.
            if future.done():
                return

            n = result if length is None else min(length - len(the_list), result)

            buffer = self._data()
            for _ in range(n):
                the_list.append(buffer.popleft())

            reached_length = length is not None and len(the_list) >= length
            if reached_length or not self.alive:
                future.set_result(the_list)
            else:
                self._fetch_into_list(length, the_list, future)
        except asyncio.CancelledError:
            # CancelledError is a BaseException (3.8+). Without this branch it
            # would bypass the handler below and the caller would hang forever.
            if not future.done():
                future.cancel()
        except Exception as exc:  # skipcq: PYL-W0703
            if not future.done():
                future.set_exception(exc)

    async def _refresh(self) -> int:
        """Refill the buffer off the event loop and return its size."""
        return await run_sync(self.dispatch._refresh)  # skipcq: PYL-W0212

    def batch_size(self, batch_size: int) -> "AsyncCursorBase":
        """Set the batch size on the underlying cursor; returns ``self`` for chaining."""
        self.dispatch.batch_size(batch_size)
        return self

    async def close(self) -> None:
        """Close the underlying cursor. Calls after the first are no-ops."""
        if not self.closed:
            self.closed = True
            await run_sync(self.dispatch.close)

    async def next(self) -> Any:
        """Return the next document, fetching a new batch if the buffer is empty.

        :raises StopAsyncIteration: when no further documents are available.
        """
        if self.alive and (self._buffer_size() or await self._get_more()):
            return await run_sync(next, self.dispatch)
        raise StopAsyncIteration

    def to_list(self, length: Optional[int] = None) -> asyncio.Future[List[Mapping[str, Any]]]:
        """Collect documents from the cursor into a list.

        :param length: maximum number of documents to return, or ``None``
            to drain the cursor.
        :returns: a future that resolves to the collected list.
        :raises TypeError: if *length* is neither ``None`` nor an ``int``.
        :raises ValueError: if *length* is negative.
        :raises InvalidOperation: if the cursor is tailable.
        """
        if length is not None:
            # Reject non-integers: a float would otherwise return ceil(length) documents.
            if not isinstance(length, int):
                raise TypeError(f"length must be an int, not {length!r}")
            if length < 0:
                raise ValueError("length must be non-negative")

        if self._query_flags() & _QUERY_OPTIONS["tailable_cursor"]:
            raise InvalidOperation("Can't call to_list on tailable cursor")

        future = self.loop.create_future()
        the_list: List[Mapping[str, Any]] = []

        if not self.alive:
            future.set_result(the_list)
            return future

        self._fetch_into_list(length, the_list, future)
        return future

    @property
    def address(self) -> Optional[Union[Tuple[str, int], _Address]]:
        """The address reported by the underlying cursor."""
        return self.dispatch.address

    @property
    def alive(self) -> bool:
        """Whether more documents may still be available.

        A falsy ``dispatch`` is reported as alive. NOTE(review): presumably
        this covers a not-yet-materialised latent cursor; confirm.
        """
        if not self.dispatch:
            return True
        return self.dispatch.alive

    @property
    def cursor_id(self) -> Optional[int]:
        """The server-side cursor id reported by the underlying cursor."""
        return self.dispatch.cursor_id

    @property
    def session(self) -> Optional[ClientSession]:
        """The client session reported by the underlying cursor."""
        return self.dispatch.session
|