From 414f3966da75c17e655b88e591b31be386eaae7e Mon Sep 17 00:00:00 2001 From: profitroll Date: Fri, 24 Nov 2023 23:52:50 +0100 Subject: [PATCH] Attempt to temporarily fix #2 --- src/pyrmv/classes/board.py | 11 +++++-- src/pyrmv/classes/client.py | 7 ++--- src/pyrmv/classes/journey.py | 7 ++++- src/pyrmv/utility/__init__.py | 1 + src/pyrmv/utility/journey_ref_converter.py | 26 ++++++++++++++++ tests/conftest.py | 2 +- tests/test_client.py | 35 ++++++++++++++-------- 7 files changed, 68 insertions(+), 21 deletions(-) create mode 100644 src/pyrmv/utility/journey_ref_converter.py diff --git a/src/pyrmv/classes/board.py b/src/pyrmv/classes/board.py index b5362f0..a69c9ac 100644 --- a/src/pyrmv/classes/board.py +++ b/src/pyrmv/classes/board.py @@ -2,11 +2,15 @@ from datetime import datetime from typing import Any, Mapping from pyrmv.classes.message import Message +from pyrmv.utility import ref_upgrade class LineArrival: def __init__(self, data: Mapping[str, Any], client, retrieve_stops: bool = True): - self.journey = client.journey_detail(data["JourneyDetailRef"]["ref"]) + # Upgrade is temporarily used due to RMV API mismatch + # self.journey = client.journey_detail(data["JourneyDetailRef"]["ref"]) + self.journey = client.journey_detail(ref_upgrade(data["JourneyDetailRef"]["ref"])) + self.status = data["JourneyStatus"] self.messages = [] self.name = data["name"] @@ -40,7 +44,10 @@ class LineArrival: class LineDeparture: def __init__(self, data: Mapping[str, Any], client, retrieve_stops: bool = True): - self.journey = client.journey_detail(data["JourneyDetailRef"]["ref"]) + # Upgrade is temporarily used due to RMV API mismatch + # self.journey = client.journey_detail(data["JourneyDetailRef"]["ref"]) + self.journey = client.journey_detail(ref_upgrade(data["JourneyDetailRef"]["ref"])) + self.status = data["JourneyStatus"] self.messages = [] self.name = data["name"] diff --git a/src/pyrmv/classes/client.py b/src/pyrmv/classes/client.py index ef103b0..1c04c1f 100644 --- a/src/pyrmv/classes/client.py +++ b/src/pyrmv/classes/client.py @@ -121,7 +121,7 @@ class Client: * BoardArrival: Instance of `BoardArrival` object. """ - if isinstance(direction, Stop) or isinstance(direction, StopTrip): + if isinstance(direction, (Stop, StopTrip)): direction = direction.id board_raw = raw_board_arrival( @@ -827,10 +827,7 @@ class Client: * List[Trip]: List of `Trip` objects. Empty list if none found. """ - if real_time_mode == None: - real_time_mode = None - else: - real_time_mode = real_time_mode.code + real_time_mode = None if real_time_mode is None else real_time_mode.code if isinstance(context, Trip): context = context.ctx_recon diff --git a/src/pyrmv/classes/journey.py b/src/pyrmv/classes/journey.py index 1cd2f37..32aa574 100644 --- a/src/pyrmv/classes/journey.py +++ b/src/pyrmv/classes/journey.py @@ -2,6 +2,7 @@ from typing import Any, Mapping from pyrmv.classes.message import Message from pyrmv.classes.stop import Stop +from pyrmv.utility import ref_upgrade class Journey: @@ -9,7 +10,11 @@ class Journey: def __init__(self, data: Mapping[str, Any]): self.stops = [] - self.ref = data["ref"] + + # Upgrade is temporarily used due to RMV API mismatch + # self.ref = data["ref"] + self.ref = ref_upgrade(data["ref"]) + self.direction = data["Directions"]["Direction"][0]["value"] self.direction_flag = data["Directions"]["Direction"][0]["flag"] self.stops.extend(Stop(stop) for stop in data["Stops"]["Stop"]) diff --git a/src/pyrmv/utility/__init__.py b/src/pyrmv/utility/__init__.py index 1a20d53..c88aa0e 100644 --- a/src/pyrmv/utility/__init__.py +++ b/src/pyrmv/utility/__init__.py @@ -1,2 +1,3 @@ from .find_exception import find_exception +from .journey_ref_converter import ref_upgrade from .weekdays_bitmask import weekdays_bitmask diff --git a/src/pyrmv/utility/journey_ref_converter.py b/src/pyrmv/utility/journey_ref_converter.py new file mode 100644 index 0000000..2449a2a --- /dev/null +++ b/src/pyrmv/utility/journey_ref_converter.py @@ -0,0 +1,26 @@ +def ref_upgrade(ref: str) -> str: + """This function converts older journey refs to the newer ones. + + ### WARNING + This function will be deprecated as soon as RMV updates their API + + ### Args: + * ref (`str`): Old ref like this one: `2|#VN#1#ST#1700765441#PI#0#ZI#160749#TA#0#DA#241123#1S#3004646#1T#2228#LS#3006907#LT#2354#PU#80#RT#1#CA#S30#ZE#S1#ZB# S1#PC#3#FR#3004646#FT#2228#TO#3006907#TT#2354#` + + ### Raises: + * `KeyError`: Some required keys are not found in the ref provided + + ### Returns: + * `str`: Ref of the new type + """ + + items = "|".join(ref.split("|")[1:]).strip("#").split("#") + result = {items[i]: items[i + 1] for i in range(0, len(items), 2)} + + for required in ["VN", "ZI", "TA", "PU"]: + if required not in result: + raise KeyError( + f"Required key {required} in the old journey ref is not found during conversion to the newer journey ref" + ) + + return "|".join([result["VN"], result["ZI"], result["TA"], result["PU"]]) diff --git a/tests/conftest.py b/tests/conftest.py index 53a6c4c..fbe4293 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -23,7 +23,7 @@ def sample_stop_id() -> str: @pytest.fixture() def sample_journey_id() -> str: - return "2|#VN#1#ST#1664906549#PI#0#ZI#12709#TA#0#DA#61022#1S#3008007#1T#1248#LS#3008043#LT#1323#PU#80#RT#1#CA#1aE#ZE#101#ZB#Bus 101 #PC#6#FR#3008007#FT#1248#TO#3008043#TT#1323#" + return "1|12709|0|80" @pytest.fixture() diff --git a/tests/test_client.py b/tests/test_client.py index 4f5a7b1..8577e4e 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -26,15 +26,14 @@ def test_him_search(api_client: Client): ) -# Does not work as it should yet -# def test_journey_detail(api_client: Client, sample_journey_id: str): -# assert ( -# api_client.journey_detail( -# sample_journey_id, -# real_time_mode=enums.RealTimeMode.FULL, -# ), -# Journey, -# ) +def test_journey_detail(api_client: Client, sample_journey_id: str): + assert ( + api_client.journey_detail( + sample_journey_id, + real_time_mode=enums.RealTimeMode.FULL, + ), + Journey, + ) def test_stop_by_coords(api_client: Client, sample_origin: List[str]): @@ -62,9 +61,21 @@ def test_trip_find( ) -# Does not work as it should yet -# def test_trip_recon(api_client: Client): -# assert api_client.trip_recon() +def test_trip_recon( + api_client: Client, sample_origin: List[str], sample_destination: List[float] +): + assert isinstance( + api_client.trip_recon( + api_client.trip_find( + origin_coord_lat=sample_origin[0], + origin_coord_lon=sample_origin[1], + destination_coord_lat=sample_destination[0], + destination_coord_lon=sample_destination[1], + messages=True, + )[0], + )[0], + Trip, + ) def test_stop_by_name(api_client: Client):