From e29ca6c29a253f209ee794796ca5f1a92b2b9354 Mon Sep 17 00:00:00 2001 From: Praneeth Devunuri Date: Mon, 30 Jun 2025 00:01:27 -0500 Subject: [PATCH] Enhance date filter functionality in `get_gtfs_segments` and `get_bus_feed`. Added support for single and multiple date inputs, including validation and error handling for invalid dates. Updated documentation and tests to reflect these changes. --- README.md | 45 +- docs/usage.md | 57 ++- gtfs_segments/__init__.py | 73 +-- gtfs_segments/gtfs_segments.py | 847 ++++++++++++++++---------------- gtfs_segments/partridge_func.py | 291 +++++++---- requirements.txt | 21 +- test/test_gtfs_segments.py | 286 +++++++++++ 7 files changed, 1064 insertions(+), 556 deletions(-) diff --git a/README.md b/README.md index db08908..4632e44 100644 --- a/README.md +++ b/README.md @@ -208,11 +208,50 @@ segments_df = get_gtfs_segments("path_to_gtfs_zip_file", parallel = True) Alternatively, filter a specific agency by passing `agency_id` as a string or multiple agencies as list ["SFMTA",] -``` +```python segments_df = get_gtfs_segments("path_to_gtfs_zip_file",agency_id = "SFMTA") segments_df ``` +### Analyze Specific Dates + +By default, `get_gtfs_segments` analyzes the busiest day in the GTFS schedule. You can specify a particular date/dates for analysis using the `date` parameter: + +#### **Single Date Analysis** +```python +# Analyze a specific day using string format +segments_df = get_gtfs_segments("path_to_gtfs_zip_file", date="20240317") # YYYYMMDD format + +# Using datetime.date object +from datetime import date +segments_df = get_gtfs_segments("path_to_gtfs_zip_file", date=date(2024, 3, 17)) + +# Combined with other parameters +segments_df = get_gtfs_segments( + "path_to_gtfs_zip_file", + agency_id="SFMTA", + date="20240317", + max_spacing=1000, + parallel=True +) +``` + +#### **Multiple Date Analysis** +```python +# Multiple dates using strings +segments_df = get_gtfs_segments( + "path_to_gtfs_zip_file", + date=["20220315", "20220316", "20220317"] # Tue, Wed, Thu +) + +# Mixed date types (strings and date objects) +from datetime import date +segments_df = get_gtfs_segments( + "path_to_gtfs_zip_file", + date=["20220315", date(2022, 3, 16)] +) + +
+```
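+
+#### **Handling Invalid Dates**
+
+Dates are validated against the feed's service calendar. By default an out-of-calendar date raises a `ValueError` listing the available date range; set `skip_invalid_dates=True` to drop invalid entries from a date list and continue with the rest. A small sketch (`"20990101"` is a placeholder for an out-of-calendar date):
+
+```python
+segments_df = get_gtfs_segments(
+    "path_to_gtfs_zip_file",
+    date=["20220315", "20990101"],
+    skip_invalid_dates=True,  # warns about skipped dates instead of raising
+)
+```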
@@ -223,7 +262,7 @@ Table generated by gtfs-segments using data from San Francisco’s Muni system. 3. `stop_id2`: The identifier of the segment's ending stop. 4. `route_id`: The same route ID listed in the agency's routes.txt file. 5. `direction_id`: The route's direction identifier. -6. `traversals`: The number of times the indicated route traverses the segment during the "measurement interval." The "measurement interval" chosen is the busiest day in the GTFS schedule: the day which has the most bus services running. +6. `traversals`: The number of times the indicated route traverses the segment during the "measurement interval." The "measurement interval" can be a single date, multiple dates (aggregated), or the busiest day in the GTFS schedule (default behavior). 7. `distance`: The length of the bus segment in meters. 8. `geometry`: The segment's LINESTRING (a format for encoding geographic paths) written in WGS84 (EPGS:4326) coordinates, that is, unprojected longitude-latitude pairs, as used in GTFS. 9. `traversal_time`: The time (in seconds) that it takes for the bus to traverse the segment. @@ -454,3 +493,5 @@ Project Link: [https://github.com/UTEL-UIUC/gtfs_segments](https://github.com/UT [issues-url]: https://github.com/UTEL-UIUC/gtfs_segments/issues [license-shield]: https://img.shields.io/github/license/UTEL-UIUC/gtfs_segments.svg?style=for-the-badge [license-url]: https://github.com/UTEL-UIUC/gtfs_segments/blob/master/LICENSE + + diff --git a/docs/usage.md b/docs/usage.md index 031e324..b5f3359 100644 --- a/docs/usage.md +++ b/docs/usage.md @@ -46,10 +46,58 @@ from gtfs_segments import get_gtfs_segments segments_df = get_gtfs_segments("path_to_gtfs_zip_file") ``` Alternatively filter a specific agency by passing `agency_id` as a string or multiple agencies as list ["SFMTA",] -``` +```python segments_df = get_gtfs_segments("path_to_gtfs_zip_file",agency_id = "SFMTA") segments_df ``` + +### Analyze Specific Dates + +By default, `get_gtfs_segments` analyzes the busiest day in the GTFS schedule. 
You can instead analyze a particular date, or a list of dates whose services are aggregated, by passing the `date` parameter:
+
+#### **Single Date Analysis**
+```python
+# Analyze a specific Sunday using string format
+segments_df = get_gtfs_segments("path_to_gtfs_zip_file", date="20240317")  # YYYYMMDD format
+
+# Using a datetime.date object
+from datetime import date
+segments_df = get_gtfs_segments("path_to_gtfs_zip_file", date=date(2024, 3, 17))
+
+# Combined with other parameters
+segments_df = get_gtfs_segments(
+    "path_to_gtfs_zip_file",
+    agency_id="SFMTA",
+    date="20240317",
+    max_spacing=1000,
+    parallel=True
+)
+```
+
+#### **Multiple Date Analysis**
+```python
+# Multiple dates using strings
+segments_df = get_gtfs_segments(
+    "path_to_gtfs_zip_file",
+    date=["20220315", "20220316", "20220317"]  # Tue, Wed, Thu
+)
+
+# Mixed date types (strings and date objects)
+from datetime import date
+segments_df = get_gtfs_segments(
+    "path_to_gtfs_zip_file",
+    date=["20220315", date(2022, 3, 16)]
+)
+```
+
+#### **Performance & Technical Notes**
+
+- **Service Aggregation**: When multiple dates are provided, the service IDs active on all of those dates are combined, so `traversals` are aggregated across dates
+- **Memory Usage**: Larger date ranges use more memory during processing
+- **Processing Time**: Multiple dates can increase processing time; `parallel=True` helps offset this
+- **Backward Compatibility**: All existing functionality remains unchanged; the new parameters are optional
+- **Date Validation**: Dates are validated against the GTFS service calendar; an invalid date raises an error that reports the feed's available date range (see the example below)
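+
+#### **Error Handling**
+
+With the default strict validation, a date outside the feed's service calendar raises a `ValueError` whose message includes the available date range. A minimal sketch (here `"20990101"` is a placeholder for an out-of-calendar date):
+
+```python
+from gtfs_segments import get_gtfs_segments
+
+try:
+    segments_df = get_gtfs_segments("path_to_gtfs_zip_file", date="20990101")
+except ValueError as e:
+    # Message resembles: "Date 2099-01-01 not found in GTFS service calendar.
+    # Available date range: <first date> to <last date>"
+    print(f"Could not analyze requested date: {e}")
+
+# Or skip invalid entries in a list and continue with the valid ones
+segments_df = get_gtfs_segments(
+    "path_to_gtfs_zip_file",
+    date=["20220315", "20990101"],
+    skip_invalid_dates=True,  # prints a warning for each skipped date
+)
+```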
@@ -60,7 +108,7 @@ Table generated by gtfs-segments using data from the San Francisco’s Muni syst 3. `stop_id2`: The identifier of the segment's ending stop. 4. `route_id`: The same route ID listed in the agency's routes.txt file. 5. `direction_id`: The route's direction identifier. -6. `traversals`: The number of times the indicated route traverses the segment during the "measurement interval." The "measurement interval" chosen is the busiest day in the GTFS schedule: the day which has the most bus services running. +6. `traversals`: The number of times the indicated route traverses the segment during the "measurement interval." The "measurement interval" can be a single date, multiple dates (aggregated), or the busiest day in the GTFS schedule (default behavior). 7. `distance`: The length of the segment in meters. 8. `geometry`: The segment's LINESTRING (a format for encoding geographic paths). All geometries are re-projected onto Mercator (EPSG:4326/WGS84) to maintain consistency. @@ -103,7 +151,7 @@ summary_stats(segments_df,max_spacing = 3000,export = True,file_path = "summary. ## Get Route Summary Stats ```python from gtfs_segments import get_route_stats,get_bus_feed -_,feed = get_bus_feed('path_to_gtfs.zip') +feed = get_bus_feed('path_to_gtfs.zip') get_route_stats(feed) ``` Here each row contains the following columns: @@ -134,3 +182,6 @@ export_segments(segments_df,'filename', output_format ='csv',geometry = False) ```
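+
+Since `get_bus_feed` accepts the same `date` argument, route-level stats can also be computed for a specific service date (illustrative date shown):
+
+```python
+from gtfs_segments import get_route_stats, get_bus_feed
+feed = get_bus_feed('path_to_gtfs.zip', date="20220315")
+get_route_stats(feed)
+```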

(back to top)

+ + + diff --git a/gtfs_segments/__init__.py b/gtfs_segments/__init__.py index f197d21..008cd71 100644 --- a/gtfs_segments/__init__.py +++ b/gtfs_segments/__init__.py @@ -1,34 +1,39 @@ -""" -The gtfs_segments package main init file. -""" -import importlib.metadata -from .geom_utils import view_heatmap, view_spacings, view_spacings_interactive -from .gtfs_segments import get_gtfs_segments, pipeline_gtfs, process_feed -from .mobility import ( - download_latest_data, - fetch_gtfs_source, - summary_stats_mobility, -) -from .partridge_func import get_bus_feed -from .route_stats import get_route_stats -from .utils import export_segments, plot_hist, process, summary_stats - -__version__ = importlib.metadata.version("gtfs_segments") -__all__ = [ - "__version__", - "get_gtfs_segments", - "pipeline_gtfs", - "process_feed", - "export_segments", - "plot_hist", - "fetch_gtfs_source", - "summary_stats", - "process", - "view_spacings", - "view_spacings_interactive", - "view_heatmap", - "summary_stats_mobility", - "download_latest_data", - "get_route_stats", - "get_bus_feed", -] +""" +The gtfs_segments package main init file. +""" +import importlib.metadata +from .geom_utils import view_heatmap, view_spacings, view_spacings_interactive +from .gtfs_segments import get_gtfs_segments, pipeline_gtfs, process_feed +from .mobility import ( + download_latest_data, + fetch_gtfs_source, + summary_stats_mobility, +) +from .partridge_func import get_bus_feed +from .route_stats import get_route_stats +from .utils import export_segments, plot_hist, process, summary_stats + +try: + __version__ = importlib.metadata.version("gtfs_segments") +except importlib.metadata.PackageNotFoundError: + # Fallback version for development/testing when package is not installed + __version__ = "dev" + +__all__ = [ + "__version__", + "get_gtfs_segments", + "pipeline_gtfs", + "process_feed", + "export_segments", + "plot_hist", + "fetch_gtfs_source", + "summary_stats", + "process", + "view_spacings", + "view_spacings_interactive", + "view_heatmap", + "summary_stats_mobility", + "download_latest_data", + "get_route_stats", + "get_bus_feed", +] diff --git a/gtfs_segments/gtfs_segments.py b/gtfs_segments/gtfs_segments.py index fbb9957..3281d3d 100644 --- a/gtfs_segments/gtfs_segments.py +++ b/gtfs_segments/gtfs_segments.py @@ -1,421 +1,426 @@ -import os -from typing import List, Optional, Set - -import geopandas as gpd -import numpy as np -import pandas as pd -from shapely.geometry import LineString - -from .geom_utils import ( - get_zone_epsg, - make_gdf, - nearest_points, - nearest_points_parallel, - ret_high_res_shape, -) -from .mobility import summary_stats_mobility -from .partridge_func import get_bus_feed -from .partridge_mod.gtfs import Feed -from .utils import download_write_file, export_segments, failed_pipeline, plot_hist - - -def merge_trip_geom(trip_df: pd.DataFrame, shape_df: gpd.GeoDataFrame) -> gpd.GeoDataFrame: - """ - It takes a dataframe of trips and a dataframe of shapes, and returns a geodataframe of trips with - the geometry of the shapes - - Args: - trip_df: a dataframe of trips - shape_df: a GeoDataFrame of the shapes.txt file - - Returns: - A GeoDataFrame - """ - trips_with_no_shape_id = list(trip_df[trip_df["shape_id"].isna()].trip_id) - if len(trips_with_no_shape_id) > 0: - print("Excluding Trips with no shape_id:", trips_with_no_shape_id) - trip_df = trip_df[~trip_df["trip_id"].isin(trips_with_no_shape_id)] - - non_existent_shape_id = set(trip_df["shape_id"]) - set(shape_df["shape_id"]) - if 
len(non_existent_shape_id) > 0: - trips_with_no_corresponding_shape = list(trip_df[trip_df["shape_id"].isin(non_existent_shape_id)].trip_id) - print("Excluding Trips with non-existent shape_ids in shapes.txt:", trips_with_no_corresponding_shape) - trip_df = trip_df[~trip_df["shape_id"].isin(non_existent_shape_id)] - - # `direction_id` and `shape_id` are optional - if "direction_id" in trip_df.columns: - # Check is direction_ids are listed as null - if trip_df["direction_id"].isnull().sum() == 0: - grp = trip_df.groupby(["route_id", "shape_id", "direction_id"]) - else: - grp = trip_df.groupby(["route_id", "shape_id"]) - else: - grp = trip_df.groupby(["route_id", "shape_id"]) - trip_df = grp.first().reset_index() - trip_df["traversals"] = grp.count().reset_index(drop=True)["trip_id"] - subset_list = np.array( - ["route_id", "trip_id", "shape_id", "service_id", "direction_id", "traversals"] - ) - col_subset = subset_list[np.in1d(subset_list, trip_df.columns)] - trip_df = trip_df[col_subset] - trip_df = trip_df.dropna(how="all", axis=1) - trip_df = shape_df.merge(trip_df, on="shape_id", how="left") - return make_gdf(trip_df) - - -def make_segments_unique(df: gpd.GeoDataFrame, traversal_threshold: int = 1) -> gpd.GeoDataFrame: - # Compute the number of unique rounded distances for each route_id and segment_id - unique_counts = df.groupby(["route_id", "segment_id"])["distance"].apply( - lambda x: x.round().nunique() - ) - - # Filter rows where unique count is greater than 1 - filtered_df = df[ - df.set_index(["route_id", "segment_id"]).index.isin(unique_counts[unique_counts > 1].index) - ].copy() - - # Create a segment modification function - def modify_segment(segment_id: str, count: int) -> str: - seg_split = str(segment_id).split("-") - return seg_split[0] + "-" + seg_split[1] + "-" + str(count + 1) - - # Apply the modification function to the segment_id - filtered_df["modification"] = filtered_df.groupby(["route_id", "segment_id"]).cumcount() - filtered_df["segment_id"] = filtered_df.apply( - lambda row: modify_segment(row["segment_id"], row["modification"]) - if row["modification"] != 0 - else row["segment_id"], - axis=1, - ) - - # Merge the modified segments back into the original DataFrame - df = pd.concat([df[~df.index.isin(filtered_df.index)], filtered_df], ignore_index=True) - - # Aggregate traversals and filter by traversal threshold - grp_again = df.groupby(["route_id", "segment_id"]) - df = grp_again.first().reset_index() - df["traversals"] = grp_again["traversals"].sum().values - df = df[df.traversals > traversal_threshold].reset_index(drop=True) - return make_gdf(df) - - -def filter_stop_df(stop_df: pd.DataFrame, trip_ids: Set, stop_loc_df: pd.DataFrame) -> pd.DataFrame: - """ - It takes a dataframe of stops and a list of trip IDs and returns a dataframe of stops that are in - the list of trip IDs - - Args: - stop_df: the dataframe of all stops - trip_ids: a list of trip_ids that you want to filter the stop_df by - - Returns: - A dataframe with the trip_id, s top_id, and stop_sequence for the trips in the trip_ids list. 
- """ - missing_stop_locs = set(stop_df.stop_id) - set(stop_loc_df.stop_id) - if len(missing_stop_locs) > 0: - print("Missing stop locations for:", missing_stop_locs) - missing_trips = stop_df[stop_df.stop_id.isin(missing_stop_locs)].trip_id.unique() - for trip in missing_trips: - trip_ids.discard(trip) - print( - "Removed the trip_id:", trip, "as stop locations are missing for stops in the trip" - ) - # Filter the stop_df to only include the trip_ids in the trip_ids list - stop_df = stop_df[stop_df.trip_id.isin(trip_ids)].reset_index(drop=True) - stop_df = stop_df.sort_values(["trip_id", "stop_sequence"]).reset_index(drop=True) - stop_df["main_index"] = stop_df.index - stop_df_grp = stop_df.groupby("trip_id") - drop_inds = [] - # To eliminate deadheads - if "pickup_type" in stop_df.columns: - grp_f = stop_df_grp.first() - drop_inds.append(grp_f.loc[grp_f["pickup_type"] == 1, "main_index"]) - if "drop_off_type" in stop_df.columns: - grp_l = stop_df_grp.last() - drop_inds.append( - grp_l.loc[grp_l["drop_off_type"] == 1, "main_index"] - ) # Fixed the variable name from grp_f to grp_l - if len(drop_inds) > 0 and len(drop_inds[0]) > 0: - stop_df = stop_df[~stop_df["main_index"].isin(drop_inds)].reset_index(drop=True) - stop_df = stop_df[["trip_id", "stop_id", "stop_sequence", "arrival_time"]] - - stop_df = stop_df.sort_values(["trip_id", "stop_sequence"]).reset_index(drop=True) - return stop_df - - -def merge_stop_geom(stop_df: pd.DataFrame, stop_loc_df: pd.DataFrame) -> gpd.GeoDataFrame: - """ - > Merge the stop_loc_df with the stop_df, and then convert the result to a GeoDataFrame - - Args: - stop_df: a dataframe of stops - stop_loc_df: a GeoDataFrame of the stops - - Returns: - A GeoDataFrame - """ - stop_df["start"] = stop_df.copy().merge(stop_loc_df, how="left", on="stop_id")["geometry"] - return stop_df - - -def create_segments(stop_df: gpd.GeoDataFrame, parallel: bool = False) -> pd.DataFrame: - """ - This function creates segments between stops based on their proximity and returns a GeoDataFrame. - - Args: - stop_df: A pandas DataFrame containing information about stops on a transit network, including - their stop_id, coordinates, and trip_id. - - Returns: - a GeoDataFrame with segments created from the input stop_df. 
- """ - if parallel: - stop_df = nearest_points_parallel(stop_df) - else: - stop_df = nearest_points(stop_df) - stop_df = stop_df.rename({"stop_id": "stop_id1", "arrival_time": "arrival_time1"}, axis=1) - grp = ( - pd.DataFrame(stop_df).groupby("trip_id", group_keys=False).shift(-1).reset_index(drop=True) - ) - stop_df[["stop_id2", "end", "snap_end_id", "arrival_time2"]] = grp[ - ["stop_id1", "start", "snap_start_id", "arrival_time1"] - ] - stop_df["segment_id"] = stop_df.apply( - lambda row: str(row["stop_id1"]) + "-" + str(row["stop_id2"]) + "-1", axis=1 - ) - stop_df = stop_df.dropna().reset_index(drop=True) - stop_df.snap_end_id = stop_df.snap_end_id.astype(int) - stop_df = stop_df[stop_df["snap_end_id"] > stop_df["snap_start_id"]].reset_index(drop=True) - stop_df["geometry"] = stop_df.apply( - lambda row: LineString( - row["geometry"].coords[row["snap_start_id"] : row["snap_end_id"] + 1] - ), - axis=1, - ) - return stop_df - - -def process_feed_stops(feed: Feed) -> gpd.GeoDataFrame: - """ - It takes a GTFS feed, merges the trip and shape data, filters the stop_times data to only include - the trips that are in the feed, merges the stop_times data with the stop data, creates a segment for - each stop pair, gets the EPSG zone for the feed, creates a GeoDataFrame, and calculates the length - of each segment - - Args: - feed: a GTFS feed object - max_spacing: the maximum distance between stops in meters. If a stop is more than this distance - from the previous stop, it will be dropped. - - Returns: - A GeoDataFrame with the following columns: - """ - trip_df = merge_trip_geom(feed.trips, feed.shapes) - trip_ids = set(trip_df.trip_id.unique()) - stop_loc_df = feed.stops[["stop_id", "geometry"]] - stop_df = filter_stop_df(feed.stop_times, trip_ids, stop_loc_df) - stop_df = merge_stop_geom(stop_df, stop_loc_df) - stop_df = stop_df.merge(trip_df, on="trip_id", how="left") - stops = stop_df.groupby("shape_id").count().reset_index()["geometry"] - stop_df = stop_df.groupby("shape_id").first().reset_index() - stop_df["n_stops"] = stops - epsg_zone = get_zone_epsg(stop_df) - if epsg_zone is not None: - stop_df["distance"] = stop_df.geometry.to_crs(epsg_zone).length - stop_df["mean_distance"] = stop_df["distance"] / stop_df["n_stops"] - return make_gdf(stop_df) - - -def process_feed( - feed: Feed, parallel: bool = False, max_spacing: Optional[float] = None -) -> gpd.GeoDataFrame: - """ - The function `process_feed` takes a feed and optional maximum spacing as input, performs various - data processing and filtering operations on the feed, and returns a GeoDataFrame containing the - processed data. - - Args: - feed: The `feed` parameter is a data structure that contains information about a transit network. - It likely includes data such as shapes (geometric representations of routes), trips (sequences of - stops), stop times (arrival and departure times at stops), and stops (locations of stops). - [Optional] max_spacing: The `max_spacing` parameter is an optional parameter that specifies the maximum - distance between stops. If provided, the function will filter out stops that are farther apart than - the specified maximum spacing. - - Returns: - A GeoDataFrame containing information about the stops and segments in the feed with segments smaller than the max_spacing values. 
- """ - # Set a Spatial Resolution and increase the resolution of the shapes - # shapes = ret_high_res_shape_parallel(feed.shapes, spat_res=5) - ## Note: Currently, the parallel version of the function ret_high_res_shape_parallel is not working as expected and is slower than the non-parallel version - shapes = ret_high_res_shape(feed.shapes, feed.trips, spat_res=5) - trip_df = merge_trip_geom(feed.trips, shapes) - trip_ids = set(trip_df.trip_id.unique()) - stop_loc_df = feed.stops[["stop_id", "geometry"]] - stop_df = filter_stop_df(feed.stop_times, trip_ids, stop_loc_df) - stop_df = merge_stop_geom(stop_df, stop_loc_df) - stop_df = stop_df.merge(trip_df, on="trip_id", how="left") - stop_df = create_segments(stop_df, parallel=parallel) - stop_df = make_gdf(stop_df) - epsg_zone = get_zone_epsg(stop_df) - if epsg_zone is not None: - stop_df["distance"] = stop_df.set_geometry("geometry").to_crs(epsg_zone).geometry.length - stop_df["distance"] = stop_df["distance"].round(2) # round to 2 decimal places - stop_df["traversal_time"] = (stop_df["arrival_time2"] - stop_df["arrival_time1"]).astype( - "float" - ) - stop_df["speed"] = stop_df["distance"].div(stop_df["traversal_time"]) - stop_df = make_segments_unique(stop_df, traversal_threshold=0) - subset_list = np.array( - [ - "segment_id", - "route_id", - "direction_id", - "trip_id", - "traversals", - "distance", - "stop_id1", - "stop_id2", - "traversal_time", - "speed", - "geometry", - ] - ) - col_subset = subset_list[np.in1d(subset_list, stop_df.columns)] - stop_df = stop_df[col_subset] - if max_spacing is not None: - stop_df = stop_df[stop_df["distance"] <= max_spacing] - return make_gdf(stop_df) - - -def inspect_feed(feed: Feed) -> str: - """ - It checks to see if the feed has any bus routes and if it has a `shape_id` column in the `trips` - table - - Args: - feed: The feed object that you want to inspect. - - Returns: - A message - """ - message = "Valid GTFS Feed" - if len(feed.stop_times) == 0: - message = "No Bus Routes in " - if "shape_id" not in feed.trips.columns: - message = "Missing `shape_id` column in " - return message - - -def get_gtfs_segments( - path: str, - agency_id: Optional[str] = None, - threshold: Optional[int] = 1, - max_spacing: Optional[float] = None, - parallel: bool = False, -) -> gpd.GeoDataFrame: - """ - The function `get_gtfs_segments` takes a path to a GTFS feed file, an optional agency name, a - threshold value, and an optional maximum spacing value, and returns processed GTFS segments. - - Args: - path: The path parameter is the file path to the GTFS (General Transit Feed Specification) data. - This is the data format used by public transportation agencies to provide schedule and geographic - information about their services. - [Optional] agency_id: The agency_id of the transit agency for which you want to retrieve the bus feed. If this - parameter is not provided, the function will retrieve the bus feed for all transit agencies. You can pass - a list of agency_ids to retrieve the bus feed for multiple transit agencies. - [Optional] threshold: The threshold parameter is used to filter out bus trips that have fewer stops than the - specified threshold. Trips with fewer stops than the threshold will be excluded from the result. - Defaults to 1 - [Optional] max_spacing: The `max_spacing` parameter is used to specify the maximum distance between two - consecutive stops in a segment. If the distance between two stops exceeds the `max_spacing` value, - the segment is split into multiple segments. 
- - Returns: - A GeoDataFrame containing information about the stops and segments in the feed with segments - smaller than the max_spacing values. Each row contains the following columns: - - segment_id: the segment's identifier, produced by gtfs-segments - - stop_id1: The `stop_id` identifier of the segment's beginning stop. - The identifier is the same one the agency has chosen in the stops.txt file of its GTFS package. - - stop_id2: The `stop_id` identifier of the segment's ending stop. - - route_id: The same route ID listed in the agency's routes.txt file. - - direction_id: The route's direction identifier. - - traversals: The number of times the indicated route traverses the segment during the "measurement interval." - The "measurement interval" chosen is the busiest day in the GTFS schedule: the day which has the most bus services running. - - distance: The length of the segment in meters. - - geometry: The segment's LINESTRING (a format for encoding geographic paths). - All geometries are re-projected onto Mercator (EPSG:4326/WGS84) to maintain consistency. - """ - feed = get_bus_feed(path, agency_id=agency_id, threshold=threshold, parallel=parallel) - df = process_feed(feed, parallel=parallel) - if max_spacing is not None: - print("Using max_spacing {:.0f} to filter segments".format(max_spacing)) - df = df[df["distance"] <= max_spacing] - return df - - -def pipeline_gtfs(filename: str, url: str, bounds: List, max_spacing: float) -> str: - """ - It takes a GTFS file, downloads it, reads it, processes it, and then outputs a bunch of files. - - Let's go through the function step by step. - - First, we define the function and give it a name. We also give it a few arguments: - - - filename: the name of the file we want to save the output to. - - url: the url of the GTFS file we want to download. - - bounds: the bounding box of the area we want to analyze. - - max_spacing: the maximum spacing we want to analyze. - - We then create a folder to save the output to. - - Next, we download the GTFS file and save it to the folder we just created. - - Then, we read the GTFS file using the `get_bus_feed` function. - - Args: - filename: the name of the file you want to save the output to - url: the url of the GTFS file - bounds: the bounding box of the area you want to analyze. This is in the format - [min_lat,min_lon,max_lat,max_lon] - max_spacing: The maximum distance between stops that you want to consider. 
- - Returns: - Success or Failure of the pipeline - """ - folder_path = os.path.join("output_files", filename) - gtfs_file_loc = download_write_file(url, folder_path) - - # read file using GTFS Fucntions - feed = get_bus_feed(gtfs_file_loc) - # Remove Null entries - message = inspect_feed(feed) - if message != "Valid GTFS Feed": - return failed_pipeline(message, filename, folder_path) - - df = process_feed(feed) - df_sub = df[df["distance"] < 3000].copy().reset_index(drop=True) - if len(df_sub) == 0: - return failed_pipeline("Only Long Bus Routes in ", filename, folder_path) - # Output files and Stats - summary_stats_mobility(df, folder_path, filename, url, bounds, max_spacing, export=True) - - plot_hist( - df, - file_path=os.path.join(folder_path, "spacings.png"), - title=filename.split(".")[0], - max_spacing=max_spacing, - save_fig=True, - ) - export_segments( - df, os.path.join(folder_path, "geojson"), output_format="geojson", geometry=True - ) - export_segments( - df, - os.path.join(folder_path, "spacings_with_geometry"), - output_format="csv", - geometry=True, - ) - export_segments(df, os.path.join(folder_path, "spacings"), output_format="csv", geometry=False) - return "Success for " + filename +import os +from datetime import date as date_type +from typing import List, Optional, Set, Union + +import geopandas as gpd +import numpy as np +import pandas as pd +from shapely.geometry import LineString + +from .geom_utils import ( + get_zone_epsg, + make_gdf, + nearest_points, + nearest_points_parallel, + ret_high_res_shape, +) +from .mobility import summary_stats_mobility +from .partridge_func import get_bus_feed +from .partridge_mod.gtfs import Feed +from .utils import download_write_file, export_segments, failed_pipeline, plot_hist + + +def merge_trip_geom(trip_df: pd.DataFrame, shape_df: gpd.GeoDataFrame) -> gpd.GeoDataFrame: + """ + It takes a dataframe of trips and a dataframe of shapes, and returns a geodataframe of trips with + the geometry of the shapes + + Args: + trip_df: a dataframe of trips + shape_df: a GeoDataFrame of the shapes.txt file + + Returns: + A GeoDataFrame + """ + trips_with_no_shape_id = list(trip_df[trip_df["shape_id"].isna()].trip_id) + if len(trips_with_no_shape_id) > 0: + print("Excluding Trips with no shape_id:", trips_with_no_shape_id) + trip_df = trip_df[~trip_df["trip_id"].isin(trips_with_no_shape_id)] + + non_existent_shape_id = set(trip_df["shape_id"]) - set(shape_df["shape_id"]) + if len(non_existent_shape_id) > 0: + trips_with_no_corresponding_shape = list(trip_df[trip_df["shape_id"].isin(non_existent_shape_id)].trip_id) + print("Excluding Trips with non-existent shape_ids in shapes.txt:", trips_with_no_corresponding_shape) + trip_df = trip_df[~trip_df["shape_id"].isin(non_existent_shape_id)] + + # `direction_id` and `shape_id` are optional + if "direction_id" in trip_df.columns: + # Check is direction_ids are listed as null + if trip_df["direction_id"].isnull().sum() == 0: + grp = trip_df.groupby(["route_id", "shape_id", "direction_id"]) + else: + grp = trip_df.groupby(["route_id", "shape_id"]) + else: + grp = trip_df.groupby(["route_id", "shape_id"]) + trip_df = grp.first().reset_index() + trip_df["traversals"] = grp.count().reset_index(drop=True)["trip_id"] + subset_list = np.array( + ["route_id", "trip_id", "shape_id", "service_id", "direction_id", "traversals"] + ) + col_subset = subset_list[np.in1d(subset_list, trip_df.columns)] + trip_df = trip_df[col_subset] + trip_df = trip_df.dropna(how="all", axis=1) + trip_df = shape_df.merge(trip_df, 
on="shape_id", how="left") + return make_gdf(trip_df) + + +def make_segments_unique(df: gpd.GeoDataFrame, traversal_threshold: int = 1) -> gpd.GeoDataFrame: + # Compute the number of unique rounded distances for each route_id and segment_id + unique_counts = df.groupby(["route_id", "segment_id"])["distance"].apply( + lambda x: x.round().nunique() + ) + + # Filter rows where unique count is greater than 1 + filtered_df = df[ + df.set_index(["route_id", "segment_id"]).index.isin(unique_counts[unique_counts > 1].index) + ].copy() + + # Create a segment modification function + def modify_segment(segment_id: str, count: int) -> str: + seg_split = str(segment_id).split("-") + return seg_split[0] + "-" + seg_split[1] + "-" + str(count + 1) + + # Apply the modification function to the segment_id + filtered_df["modification"] = filtered_df.groupby(["route_id", "segment_id"]).cumcount() + filtered_df["segment_id"] = filtered_df.apply( + lambda row: modify_segment(row["segment_id"], row["modification"]) + if row["modification"] != 0 + else row["segment_id"], + axis=1, + ) + + # Merge the modified segments back into the original DataFrame + df = pd.concat([df[~df.index.isin(filtered_df.index)], filtered_df], ignore_index=True) + + # Aggregate traversals and filter by traversal threshold + grp_again = df.groupby(["route_id", "segment_id"]) + df = grp_again.first().reset_index() + df["traversals"] = grp_again["traversals"].sum().values + df = df[df.traversals > traversal_threshold].reset_index(drop=True) + return make_gdf(df) + + +def filter_stop_df(stop_df: pd.DataFrame, trip_ids: Set, stop_loc_df: pd.DataFrame) -> pd.DataFrame: + """ + It takes a dataframe of stops and a list of trip IDs and returns a dataframe of stops that are in + the list of trip IDs + + Args: + stop_df: the dataframe of all stops + trip_ids: a list of trip_ids that you want to filter the stop_df by + + Returns: + A dataframe with the trip_id, s top_id, and stop_sequence for the trips in the trip_ids list. 
+ """ + missing_stop_locs = set(stop_df.stop_id) - set(stop_loc_df.stop_id) + if len(missing_stop_locs) > 0: + print("Missing stop locations for:", missing_stop_locs) + missing_trips = stop_df[stop_df.stop_id.isin(missing_stop_locs)].trip_id.unique() + for trip in missing_trips: + trip_ids.discard(trip) + print( + "Removed the trip_id:", trip, "as stop locations are missing for stops in the trip" + ) + # Filter the stop_df to only include the trip_ids in the trip_ids list + stop_df = stop_df[stop_df.trip_id.isin(trip_ids)].reset_index(drop=True) + stop_df = stop_df.sort_values(["trip_id", "stop_sequence"]).reset_index(drop=True) + stop_df["main_index"] = stop_df.index + stop_df_grp = stop_df.groupby("trip_id") + drop_inds = [] + # To eliminate deadheads + if "pickup_type" in stop_df.columns: + grp_f = stop_df_grp.first() + drop_inds.append(grp_f.loc[grp_f["pickup_type"] == 1, "main_index"]) + if "drop_off_type" in stop_df.columns: + grp_l = stop_df_grp.last() + drop_inds.append( + grp_l.loc[grp_l["drop_off_type"] == 1, "main_index"] + ) # Fixed the variable name from grp_f to grp_l + if len(drop_inds) > 0 and len(drop_inds[0]) > 0: + stop_df = stop_df[~stop_df["main_index"].isin(drop_inds)].reset_index(drop=True) + stop_df = stop_df[["trip_id", "stop_id", "stop_sequence", "arrival_time"]] + + stop_df = stop_df.sort_values(["trip_id", "stop_sequence"]).reset_index(drop=True) + return stop_df + + +def merge_stop_geom(stop_df: pd.DataFrame, stop_loc_df: pd.DataFrame) -> gpd.GeoDataFrame: + """ + > Merge the stop_loc_df with the stop_df, and then convert the result to a GeoDataFrame + + Args: + stop_df: a dataframe of stops + stop_loc_df: a GeoDataFrame of the stops + + Returns: + A GeoDataFrame + """ + stop_df["start"] = stop_df.copy().merge(stop_loc_df, how="left", on="stop_id")["geometry"] + return stop_df + + +def create_segments(stop_df: gpd.GeoDataFrame, parallel: bool = False) -> pd.DataFrame: + """ + This function creates segments between stops based on their proximity and returns a GeoDataFrame. + + Args: + stop_df: A pandas DataFrame containing information about stops on a transit network, including + their stop_id, coordinates, and trip_id. + + Returns: + a GeoDataFrame with segments created from the input stop_df. 
+ """ + if parallel: + stop_df = nearest_points_parallel(stop_df) + else: + stop_df = nearest_points(stop_df) + stop_df = stop_df.rename({"stop_id": "stop_id1", "arrival_time": "arrival_time1"}, axis=1) + grp = ( + pd.DataFrame(stop_df).groupby("trip_id", group_keys=False).shift(-1).reset_index(drop=True) + ) + stop_df[["stop_id2", "end", "snap_end_id", "arrival_time2"]] = grp[ + ["stop_id1", "start", "snap_start_id", "arrival_time1"] + ] + stop_df["segment_id"] = stop_df.apply( + lambda row: str(row["stop_id1"]) + "-" + str(row["stop_id2"]) + "-1", axis=1 + ) + stop_df = stop_df.dropna().reset_index(drop=True) + stop_df.snap_end_id = stop_df.snap_end_id.astype(int) + stop_df = stop_df[stop_df["snap_end_id"] > stop_df["snap_start_id"]].reset_index(drop=True) + stop_df["geometry"] = stop_df.apply( + lambda row: LineString( + row["geometry"].coords[row["snap_start_id"] : row["snap_end_id"] + 1] + ), + axis=1, + ) + return stop_df + + +def process_feed_stops(feed: Feed) -> gpd.GeoDataFrame: + """ + It takes a GTFS feed, merges the trip and shape data, filters the stop_times data to only include + the trips that are in the feed, merges the stop_times data with the stop data, creates a segment for + each stop pair, gets the EPSG zone for the feed, creates a GeoDataFrame, and calculates the length + of each segment + + Args: + feed: a GTFS feed object + max_spacing: the maximum distance between stops in meters. If a stop is more than this distance + from the previous stop, it will be dropped. + + Returns: + A GeoDataFrame with the following columns: + """ + trip_df = merge_trip_geom(feed.trips, feed.shapes) + trip_ids = set(trip_df.trip_id.unique()) + stop_loc_df = feed.stops[["stop_id", "geometry"]] + stop_df = filter_stop_df(feed.stop_times, trip_ids, stop_loc_df) + stop_df = merge_stop_geom(stop_df, stop_loc_df) + stop_df = stop_df.merge(trip_df, on="trip_id", how="left") + stops = stop_df.groupby("shape_id").count().reset_index()["geometry"] + stop_df = stop_df.groupby("shape_id").first().reset_index() + stop_df["n_stops"] = stops + epsg_zone = get_zone_epsg(stop_df) + if epsg_zone is not None: + stop_df["distance"] = stop_df.geometry.to_crs(epsg_zone).length + stop_df["mean_distance"] = stop_df["distance"] / stop_df["n_stops"] + return make_gdf(stop_df) + + +def process_feed( + feed: Feed, parallel: bool = False, max_spacing: Optional[float] = None +) -> gpd.GeoDataFrame: + """ + The function `process_feed` takes a feed and optional maximum spacing as input, performs various + data processing and filtering operations on the feed, and returns a GeoDataFrame containing the + processed data. + + Args: + feed: The `feed` parameter is a data structure that contains information about a transit network. + It likely includes data such as shapes (geometric representations of routes), trips (sequences of + stops), stop times (arrival and departure times at stops), and stops (locations of stops). + [Optional] max_spacing: The `max_spacing` parameter is an optional parameter that specifies the maximum + distance between stops. If provided, the function will filter out stops that are farther apart than + the specified maximum spacing. + + Returns: + A GeoDataFrame containing information about the stops and segments in the feed with segments smaller than the max_spacing values. 
+ """ + # Set a Spatial Resolution and increase the resolution of the shapes + # shapes = ret_high_res_shape_parallel(feed.shapes, spat_res=5) + ## Note: Currently, the parallel version of the function ret_high_res_shape_parallel is not working as expected and is slower than the non-parallel version + shapes = ret_high_res_shape(feed.shapes, feed.trips, spat_res=5) + trip_df = merge_trip_geom(feed.trips, shapes) + trip_ids = set(trip_df.trip_id.unique()) + stop_loc_df = feed.stops[["stop_id", "geometry"]] + stop_df = filter_stop_df(feed.stop_times, trip_ids, stop_loc_df) + stop_df = merge_stop_geom(stop_df, stop_loc_df) + stop_df = stop_df.merge(trip_df, on="trip_id", how="left") + stop_df = create_segments(stop_df, parallel=parallel) + stop_df = make_gdf(stop_df) + epsg_zone = get_zone_epsg(stop_df) + if epsg_zone is not None: + stop_df["distance"] = stop_df.set_geometry("geometry").to_crs(epsg_zone).geometry.length + stop_df["distance"] = stop_df["distance"].round(2) # round to 2 decimal places + stop_df["traversal_time"] = (stop_df["arrival_time2"] - stop_df["arrival_time1"]).astype( + "float" + ) + stop_df["speed"] = stop_df["distance"].div(stop_df["traversal_time"]) + stop_df = make_segments_unique(stop_df, traversal_threshold=0) + subset_list = np.array( + [ + "segment_id", + "route_id", + "direction_id", + "trip_id", + "traversals", + "distance", + "stop_id1", + "stop_id2", + "traversal_time", + "speed", + "geometry", + ] + ) + col_subset = subset_list[np.in1d(subset_list, stop_df.columns)] + stop_df = stop_df[col_subset] + if max_spacing is not None: + stop_df = stop_df[stop_df["distance"] <= max_spacing] + return make_gdf(stop_df) + + +def inspect_feed(feed: Feed) -> str: + """ + It checks to see if the feed has any bus routes and if it has a `shape_id` column in the `trips` + table + + Args: + feed: The feed object that you want to inspect. + + Returns: + A message + """ + message = "Valid GTFS Feed" + if len(feed.stop_times) == 0: + message = "No Bus Routes in " + if "shape_id" not in feed.trips.columns: + message = "Missing `shape_id` column in " + return message + + +def get_gtfs_segments( + path: str, + agency_id: Optional[str] = None, + threshold: Optional[int] = 1, + max_spacing: Optional[float] = None, + parallel: bool = False, + date: Optional[Union[str, date_type, List[Union[str, date_type]]]] = None, + skip_invalid_dates: bool = False, +) -> gpd.GeoDataFrame: + """ + The function `get_gtfs_segments` takes a path to a GTFS feed file and returns processed GTFS segments + with comprehensive date filtering support. + + Args: + path: The path parameter is the file path to the GTFS (General Transit Feed Specification) data. + agency_id: Filter by specific transit agency ID. Defaults to None (all agencies). + threshold: Filter out trips with fewer stops than this threshold. Defaults to 1. + max_spacing: Maximum distance between consecutive stops in meters. Defaults to None (no limit). + parallel: If True, process the feed in parallel for improved performance. Defaults to False. + date: Specific date(s) to analyze. Can be: + - Single date string in YYYYMMDD format (e.g., '20230315') + - Single datetime.date object + - List of date strings and/or datetime.date objects + If None, uses the busiest day in the GTFS schedule. + skip_invalid_dates: If True and a list of dates is provided, skip invalid dates and continue + with valid ones. If False, raise an error for any invalid date. Defaults to False. 
+ + Returns: + A GeoDataFrame containing segment information with the following columns: + - segment_id: Unique segment identifier + - stop_id1, stop_id2: Beginning and ending stop IDs + - route_id: Route identifier from the GTFS data + - direction_id: Route direction identifier + - traversals: Number of times the route traverses this segment during the measurement period + - distance: Segment length in meters + - geometry: Segment LINESTRING geometry (EPSG:4326/WGS84) + """ + feed = get_bus_feed( + path, + agency_id=agency_id, + threshold=threshold, + parallel=parallel, + date=date, + skip_invalid_dates=skip_invalid_dates + ) + df = process_feed(feed, parallel=parallel) + if max_spacing is not None: + print("Using max_spacing {:.0f} to filter segments".format(max_spacing)) + df = df[df["distance"] <= max_spacing] + return df + + +def pipeline_gtfs(filename: str, url: str, bounds: List, max_spacing: float) -> str: + """ + It takes a GTFS file, downloads it, reads it, processes it, and then outputs a bunch of files. + + Let's go through the function step by step. + + First, we define the function and give it a name. We also give it a few arguments: + + - filename: the name of the file we want to save the output to. + - url: the url of the GTFS file we want to download. + - bounds: the bounding box of the area we want to analyze. + - max_spacing: the maximum spacing we want to analyze. + + We then create a folder to save the output to. + + Next, we download the GTFS file and save it to the folder we just created. + + Then, we read the GTFS file using the `get_bus_feed` function. + + Args: + filename: the name of the file you want to save the output to + url: the url of the GTFS file + bounds: the bounding box of the area you want to analyze. This is in the format + [min_lat,min_lon,max_lat,max_lon] + max_spacing: The maximum distance between stops that you want to consider. 
+ + Returns: + Success or Failure of the pipeline + """ + folder_path = os.path.join("output_files", filename) + gtfs_file_loc = download_write_file(url, folder_path) + + # read file using GTFS Fucntions + feed = get_bus_feed(gtfs_file_loc) + # Remove Null entries + message = inspect_feed(feed) + if message != "Valid GTFS Feed": + return failed_pipeline(message, filename, folder_path) + + df = process_feed(feed) + df_sub = df[df["distance"] < 3000].copy().reset_index(drop=True) + if len(df_sub) == 0: + return failed_pipeline("Only Long Bus Routes in ", filename, folder_path) + # Output files and Stats + summary_stats_mobility(df, folder_path, filename, url, bounds, max_spacing, export=True) + + plot_hist( + df, + file_path=os.path.join(folder_path, "spacings.png"), + title=filename.split(".")[0], + max_spacing=max_spacing, + save_fig=True, + ) + export_segments( + df, os.path.join(folder_path, "geojson"), output_format="geojson", geometry=True + ) + export_segments( + df, + os.path.join(folder_path, "spacings_with_geometry"), + output_format="csv", + geometry=True, + ) + export_segments(df, os.path.join(folder_path, "spacings"), output_format="csv", geometry=False) + return "Success for " + filename diff --git a/gtfs_segments/partridge_func.py b/gtfs_segments/partridge_func.py index 57742b2..8954d73 100644 --- a/gtfs_segments/partridge_func.py +++ b/gtfs_segments/partridge_func.py @@ -1,86 +1,205 @@ -import os -from typing import Optional - -import pandas as pd - -import gtfs_segments.partridge_mod as ptg - -from .partridge_mod.gtfs import Feed, parallel_read - - -def get_bus_feed( - path: str, agency_id: Optional[str] = None, threshold: Optional[int] = 1, parallel: bool = False -) -> Feed: - """ - The `get_bus_feed` function retrieves bus feed data from a specified path, with the option to filter - by agency name, and returns the busiest date and a GTFS feed object. - - Args: - path (str): The `path` parameter is a string that represents the path to the GTFS file. This file - contains the bus feed data. - agency_id (All): The `agency_id` parameter is an optional parameter that allows you to filter the - bus feed data by the agency name. It is used to specify the ID of the transit agency for which you - want to retrieve the bus feed data. If you provide an `agency_id`, the function will only return - data - threshold (int): The `threshold` parameter is used to filter out service IDs that have a low - frequency. It is set to a default value of 1, but you can adjust it to a different value if needed. - Service IDs with a sum of stop times greater than the threshold will be included in the returned - bus. Defaults to 1 - - Returns: - A tuple containing the busiest date and a GTFS feed object. The GTFS feed object contains - information about routes, stops, stop times, trips, and shapes for a transit agency's schedule. 
- """ - b_day, bday_service_ids = ptg.read_busiest_date(path) - print("Using the busiest day:", b_day) - all_days_s_ids_df = get_all_days_s_ids(path) - series = all_days_s_ids_df[bday_service_ids].sum(axis=0) > threshold - service_ids = series[series].index.values - route_types = [3, 700, 702, 703, 704, 705] # 701 is regional - # set of service IDs eliminated due to low frequency - removed_service_ids = set(bday_service_ids) - set(service_ids) - if len(removed_service_ids) > 0: - print("Service IDs eliminated due to low frequency:", removed_service_ids) - if agency_id is not None: - view = { - "routes.txt": {"route_type": route_types}, # Only bus routes - "trips.txt": {"service_id": service_ids}, # Busiest day only - "agency.txt": {"agency_id": agency_id}, # Eg: 'Société de transport de Montréal - } - else: - view = { - "routes.txt": {"route_type": route_types}, # Only bus routes - "trips.txt": {"service_id": service_ids}, # Busiest day only - } - feed = ptg.load_geo_feed(path, view=view) - if parallel: - num_cores = os.cpu_count() - print(":: Processing Feed in Parallel :: Number of cores:", num_cores) - parallel_read(feed) - return feed - - -def get_all_days_s_ids(path: str) -> pd.DataFrame: - """ - Read dates by service IDs from a given path, create a DataFrame, populate it with the dates and - service IDs, and fill missing values with False. - - Args: - path: The path to the GTFS file - - Returns: - A DataFrame containing dates and service IDs. - """ - dates_by_service_ids = ptg.read_dates_by_service_ids(path) - data = dates_by_service_ids - # Create a DataFrame - data_frame = pd.DataFrame(columns=sorted(list({col for row in data.keys() for col in row}))) - - # Iterate through the data and populate the DataFrame - for service_ids, dates in data.items(): - for date_value in dates: - data_frame.loc[date_value, list(service_ids)] = True - - # Fill missing values with False - data_frame.fillna(False, inplace=True) - return data_frame +import os +from datetime import datetime, date as date_type +from typing import Optional, Union, List + +import pandas as pd + +import gtfs_segments.partridge_mod as ptg + +# Set pandas option to handle future behavior of fillna +pd.set_option('future.no_silent_downcasting', True) + +from .partridge_mod.gtfs import Feed, parallel_read + + +def get_bus_feed( + path: str, + agency_id: Optional[str] = None, + threshold: Optional[int] = 1, + parallel: bool = False, + date: Optional[Union[str, date_type, List[Union[str, date_type]]]] = None, + skip_invalid_dates: bool = False +) -> Feed: + """ + The `get_bus_feed` function retrieves bus feed data from a specified path, with the option to filter + by agency name and date(s), and returns a GTFS feed object. + + Args: + path (str): The `path` parameter is a string that represents the path to the GTFS file. + agency_id (Optional[str]): Filter by specific agency ID. + threshold (int): Filter out service IDs with low frequency. Defaults to 1. + parallel (bool): If True, process the feed in parallel. Defaults to False. + date (Optional[Union[str, date_type, List[Union[str, date_type]]]]): Specific date(s) to restrict + services to. Can be: + - Single date string in YYYYMMDD format (e.g., '20230315') + - Single datetime.date object + - List of date strings and/or datetime.date objects + If None, defaults to the busiest date of the feed. + skip_invalid_dates (bool): If True and a list of dates is provided, skip invalid dates + and continue with valid ones. If False, raise an error for any invalid date. Defaults to False. 
+ + Returns: + A GTFS feed object containing route and schedule information. + + Raises: + ValueError: If date format is invalid, date not found in calendar, or no services available. + """ + # Get all days service IDs for validation and threshold filtering + all_days_s_ids_df = get_all_days_s_ids(path) + + if date is not None: + # Handle single date or list of dates + if isinstance(date, (str, datetime, date_type)): + # Single date case + dates_to_process = [date] + elif isinstance(date, list): + # List of dates case + if not date: # Empty list + raise ValueError("Date list cannot be empty") + dates_to_process = date + else: + raise ValueError(f"Invalid date type: {type(date)}. Must be str, date, or list of str/date") + + # Parse and validate all dates + target_dates = [] + invalid_dates = [] + + for single_date in dates_to_process: + try: + # Parse string date if needed + if isinstance(single_date, str): + try: + parsed_date = datetime.strptime(single_date, "%Y%m%d").date() + except ValueError: + raise ValueError(f"Date string '{single_date}' must be in YYYYMMDD format (e.g., '20230315')") + elif isinstance(single_date, datetime): + parsed_date = single_date.date() + elif isinstance(single_date, date_type): + parsed_date = single_date + else: + raise ValueError(f"Invalid date type in list: {type(single_date)}") + + # Validate date exists in GTFS calendar + if parsed_date not in all_days_s_ids_df.index: + available_dates = sorted(all_days_s_ids_df.index) + error_msg = ( + f"Date {parsed_date} not found in GTFS service calendar. " + f"Available date range: {available_dates[0]} to {available_dates[-1]}" + ) + if skip_invalid_dates: + invalid_dates.append((single_date, error_msg)) + continue + else: + raise ValueError(error_msg) + + target_dates.append(parsed_date) + + except ValueError as e: + if skip_invalid_dates: + invalid_dates.append((single_date, str(e))) + continue + else: + raise e + + # Check if we have any valid dates + if not target_dates: + raise ValueError("No valid dates found in the provided date list") + + # Report invalid dates if any were skipped + if invalid_dates and skip_invalid_dates: + print(f"Warning: Skipped {len(invalid_dates)} invalid dates:") + for invalid_date, error_msg in invalid_dates: + print(f" - {invalid_date}: {error_msg}") + + # Collect all service IDs from all target dates + all_service_ids = set() + for target_date in target_dates: + date_service_ids = all_days_s_ids_df.loc[target_date][all_days_s_ids_df.loc[target_date] > 0].index.tolist() + all_service_ids.update(date_service_ids) + + bday_service_ids = list(all_service_ids) + + # Report dates being used + if len(target_dates) == 1: + print(f"Using requested date: {target_dates[0]}") + else: + print(f"Using requested dates: {sorted(target_dates)} (total: {len(target_dates)} dates)") + + # Check if any services are running on the requested dates + if not bday_service_ids: + dates_str = ', '.join(str(d) for d in sorted(target_dates)) + raise ValueError(f"No services are running on the requested date(s): {dates_str}") + + else: + # Fallback to the busiest date + target_date, bday_service_ids = ptg.read_busiest_date(path) + print(f"Using the busiest day: {target_date}") + + # Filter service_ids by threshold frequency across all days + series = all_days_s_ids_df[bday_service_ids].sum(axis=0) > threshold + service_ids = series[series].index.values.tolist() + + # Report eliminated low-frequency service IDs + removed_service_ids = set(bday_service_ids) - set(service_ids) + if len(removed_service_ids) > 0: + 
print("Service IDs eliminated due to low frequency:", removed_service_ids) + + # Check if any services remain after filtering + if not service_ids: + if 'target_dates' in locals(): + dates_str = ', '.join(str(d) for d in sorted(target_dates)) + else: + dates_str = str(target_date) + raise ValueError(f"No services meet the threshold requirement ({threshold}) for date(s) {dates_str}") + + # Define bus route types + route_types = [3, 700, 702, 703, 704, 705] # 701 is regional + # Build view filter + if agency_id is not None: + view = { + "routes.txt": {"route_type": route_types}, # Only bus routes + "trips.txt": {"service_id": service_ids}, # Specified/busiest day only + "agency.txt": {"agency_id": agency_id}, # Specific agency + } + else: + view = { + "routes.txt": {"route_type": route_types}, # Only bus routes + "trips.txt": {"service_id": service_ids}, # Specified/busiest day only + } + + # Load the feed + feed = ptg.load_geo_feed(path, view=view) + + if parallel: + num_cores = os.cpu_count() + print(f":: Processing Feed in Parallel :: Number of cores: {num_cores}") + parallel_read(feed) + + return feed + + +def get_all_days_s_ids(path: str) -> pd.DataFrame: + """ + Read dates by service IDs from a given path, create a DataFrame, populate it with the dates and + service IDs, and fill missing values with False. + + Args: + path: The path to the GTFS file + + Returns: + A DataFrame containing dates and service IDs. + """ + dates_by_service_ids = ptg.read_dates_by_service_ids(path) + data = dates_by_service_ids + # Create a DataFrame with explicit boolean dtype + columns = sorted(list({col for row in data.keys() for col in row})) + data_frame = pd.DataFrame(columns=columns, dtype=bool) + + # Iterate through the data and populate the DataFrame + for service_ids, dates in data.items(): + for date_value in dates: + data_frame.loc[date_value, list(service_ids)] = True + + # Fill missing values with False - now using explicit dtype + data_frame = data_frame.fillna(False) + return data_frame diff --git a/requirements.txt b/requirements.txt index 0d59ae0..0b3a7a9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,18 +1,19 @@ cython>=3.0.2 contextily>=1.2.0 geopandas>=0.12.2 -isoweek==1.3.3 +isoweek>=1.3.3 matplotlib>=3.6.2 numpy>=1.24.0 pandas>=1.5.2 -pytest==7.2.0 -requests==2.32.2 -scipy==1.10.0 -setuptools==78.1.1 -shapely==2.0.0 -utm==0.7.0 -rasterio==1.3.6 +pytest>=7.2.0 +requests>=2.32.2 +scipy>=1.10.0 +setuptools>=78.1.1 +shapely>=2.0.0 +utm>=0.7.0 +rasterio>=1.3.6 faust-cchardet>=2.1.7 -isoweek==1.3.3 charset_normalizer>=3.3.0 -thefuzz>=0.22.1 \ No newline at end of file +thefuzz>=0.22.1 +folium>=0.20.0 +branca>=0.6.0 \ No newline at end of file diff --git a/test/test_gtfs_segments.py b/test/test_gtfs_segments.py index 4f15302..dbfdd4c 100644 --- a/test/test_gtfs_segments.py +++ b/test/test_gtfs_segments.py @@ -2,6 +2,7 @@ import os import unittest +from datetime import date import geopandas as gpd @@ -69,3 +70,288 @@ def test_get_gtfs_segments(self): df_max_spacing["distance"].min() >= 0, "Min spacing should be greater than or equal to 0", ) + + def test_get_gtfs_segments_with_date_string(self): + """ + Test get_gtfs_segments with date parameter using string format (YYYYMMDD). + Uses a weekday (2022-01-11) from the Ann Arbor GTFS data that has sufficient service. 
+ """ + # Test with a Tuesday date that has good service coverage + tuesday_date = "20220111" # Tuesday from Ann Arbor GTFS data + df_tuesday = get_gtfs_segments(self.gtfs_path, date=tuesday_date) + + self.assertTrue( + type(df_tuesday) == gpd.GeoDataFrame, + "get_gtfs_segments with date string should return GeoDataFrame", + ) + self.assertTrue( + len(df_tuesday) > 0, + "Tuesday segments should contain data", + ) + + def test_get_gtfs_segments_with_date_object(self): + """ + Test get_gtfs_segments with date parameter using datetime.date object. + """ + # Test with a Tuesday date using datetime.date object + tuesday_date = date(2022, 1, 11) # Tuesday from Ann Arbor GTFS data + df_tuesday = get_gtfs_segments(self.gtfs_path, date=tuesday_date) + + self.assertTrue( + type(df_tuesday) == gpd.GeoDataFrame, + "get_gtfs_segments with date object should return GeoDataFrame", + ) + self.assertTrue( + len(df_tuesday) > 0, + "Tuesday segments should contain data", + ) + + def test_get_gtfs_segments_date_vs_busiest_day(self): + """ + Test that specifying a date gives different results than the default busiest day. + This test focuses on verifying that the date parameter is being processed correctly. + """ + # Default behavior (busiest day: 2022-01-10) + df_busiest = get_gtfs_segments(self.gtfs_path) + + # Specific Tuesday (2022-01-11) which should have different service pattern + df_tuesday = get_gtfs_segments(self.gtfs_path, date="20220111") + + # Both should return valid GeoDataFrames + self.assertTrue( + type(df_busiest) == gpd.GeoDataFrame and type(df_tuesday) == gpd.GeoDataFrame, + "Both should return GeoDataFrames", + ) + + # Check that both have the expected columns + expected_columns = ["segment_id", "route_id", "traversals", "distance", "geometry"] + for col in expected_columns: + self.assertTrue( + col in df_busiest.columns, + f"Busiest day result should contain {col} column", + ) + self.assertTrue( + col in df_tuesday.columns, + f"Tuesday result should contain {col} column", + ) + + def test_get_gtfs_segments_date_combined_parameters(self): + """ + Test get_gtfs_segments with date parameter combined with other parameters. + """ + df_combined = get_gtfs_segments( + self.gtfs_path, + date="20220111", # Tuesday + threshold=2, + max_spacing=2000, + agency_id="1" + ) + + self.assertTrue( + type(df_combined) == gpd.GeoDataFrame, + "Combined parameters with date should return GeoDataFrame", + ) + if len(df_combined) > 0: # Only check if we have data + self.assertTrue( + df_combined["distance"].max() <= 2000, + "Max spacing constraint should be applied with date parameter", + ) + + def test_get_gtfs_segments_invalid_date_format(self): + """ + Test that invalid date format raises appropriate error. + """ + with self.assertRaises(ValueError) as context: + get_gtfs_segments(self.gtfs_path, date="2022-01-16") # Wrong format + + self.assertTrue( + "YYYYMMDD format" in str(context.exception), + "Should provide helpful error message for invalid date format", + ) + + def test_get_gtfs_segments_date_not_in_calendar(self): + """ + Test that date not in GTFS calendar raises appropriate error. + """ + with self.assertRaises(ValueError) as context: + get_gtfs_segments(self.gtfs_path, date="20250101") # Future date not in GTFS + + self.assertTrue( + "not found in GTFS service calendar" in str(context.exception), + "Should provide helpful error message for date not in calendar", + ) + + def test_get_bus_feed_with_date(self): + """ + Test get_bus_feed function directly with date parameter. 
+ """ + # Test string date + feed_with_date = get_bus_feed(self.gtfs_path, date="20220116") + self.assertTrue( + hasattr(feed_with_date, 'trips') and len(feed_with_date.trips) > 0, + "get_bus_feed with date should return valid feed with trips", + ) + + # Test datetime.date object + feed_with_date_obj = get_bus_feed(self.gtfs_path, date=date(2022, 1, 11)) + self.assertTrue( + hasattr(feed_with_date_obj, 'trips') and len(feed_with_date_obj.trips) > 0, + "get_bus_feed with date object should return valid feed with trips", + ) + + def test_get_gtfs_segments_with_date_list(self): + """ + Test get_gtfs_segments with a list of dates. + """ + # Test with multiple weekdays + date_list = ["20220110", "20220111", "20220112"] # Mon, Tue, Wed + df_multi = get_gtfs_segments(self.gtfs_path, date=date_list) + + self.assertTrue( + type(df_multi) == gpd.GeoDataFrame, + "get_gtfs_segments with date list should return GeoDataFrame", + ) + self.assertTrue( + len(df_multi) > 0, + "Multi-date segments should contain data", + ) + + def test_get_gtfs_segments_mixed_date_types(self): + """ + Test get_gtfs_segments with mixed date types (strings and date objects). + """ + from datetime import date + + # Mix of string and date object + date_list = ["20220110", date(2022, 1, 11)] + df_mixed = get_gtfs_segments(self.gtfs_path, date=date_list) + + self.assertTrue( + type(df_mixed) == gpd.GeoDataFrame, + "get_gtfs_segments with mixed date types should return GeoDataFrame", + ) + self.assertTrue( + len(df_mixed) > 0, + "Mixed date type segments should contain data", + ) + + def test_get_gtfs_segments_skip_invalid_dates(self): + """ + Test get_gtfs_segments with skip_invalid_dates=True. + """ + # Include one valid and one invalid date + date_list = ["20220110", "20250101"] # Valid date and future invalid date + df_skip = get_gtfs_segments( + self.gtfs_path, + date=date_list, + skip_invalid_dates=True + ) + + self.assertTrue( + type(df_skip) == gpd.GeoDataFrame, + "get_gtfs_segments with skip_invalid_dates should return GeoDataFrame", + ) + self.assertTrue( + len(df_skip) > 0, + "Should process valid dates when skipping invalid ones", + ) + + def test_get_gtfs_segments_all_invalid_dates_with_skip(self): + """ + Test error handling when all dates are invalid but skip_invalid_dates=True. + """ + # All invalid dates + date_list = ["20250101", "20250102"] + + with self.assertRaises(ValueError) as context: + get_gtfs_segments( + self.gtfs_path, + date=date_list, + skip_invalid_dates=True + ) + + self.assertIn("No valid dates found", str(context.exception)) + + def test_get_gtfs_segments_invalid_date_list_strict(self): + """ + Test error handling for invalid dates in list with skip_invalid_dates=False. + """ + # Include one invalid date with strict validation + date_list = ["20220110", "20250101"] + + with self.assertRaises(ValueError) as context: + get_gtfs_segments(self.gtfs_path, date=date_list, skip_invalid_dates=False) + + self.assertIn("not found in GTFS service calendar", str(context.exception)) + + def test_get_gtfs_segments_empty_date_list(self): + """ + Test error handling for empty date list. + """ + with self.assertRaises(ValueError) as context: + get_gtfs_segments(self.gtfs_path, date=[]) + + self.assertIn("Date list cannot be empty", str(context.exception)) + + def test_get_gtfs_segments_date_list_with_all_parameters(self): + """ + Test get_gtfs_segments with date list combined with all other parameters. 
+ """ + df_comprehensive = get_gtfs_segments( + self.gtfs_path, + date=["20220110", "20220111"], + agency_id="1", + threshold=2, + max_spacing=2000, + parallel=False, + skip_invalid_dates=False + ) + + self.assertTrue( + type(df_comprehensive) == gpd.GeoDataFrame, + "Comprehensive parameters with date list should return GeoDataFrame", + ) + + if len(df_comprehensive) > 0: + self.assertTrue( + df_comprehensive["distance"].max() <= 2000, + "Max spacing should be applied with date list", + ) + + def test_get_bus_feed_with_date_list(self): + """ + Test get_bus_feed with a list of dates. + """ + # Test with multiple dates + date_list = ["20220110", "20220111"] + feed_multi = get_bus_feed(self.gtfs_path, date=date_list) + + self.assertTrue( + hasattr(feed_multi, 'trips'), + "get_bus_feed with date list should return valid feed", + ) + self.assertTrue( + len(feed_multi.trips) > 0, + "Multi-date feed should contain trips", + ) + + def test_get_bus_feed_skip_invalid_dates(self): + """ + Test get_bus_feed with skip_invalid_dates parameter. + """ + # Include one valid and one invalid date + date_list = ["20220110", "20250101"] + feed_skip = get_bus_feed( + self.gtfs_path, + date=date_list, + skip_invalid_dates=True + ) + + self.assertTrue( + hasattr(feed_skip, 'trips'), + "get_bus_feed with skip_invalid_dates should return valid feed", + ) + self.assertTrue( + len(feed_skip.trips) > 0, + "Should process valid dates when skipping invalid ones", + )
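+
+    def test_get_bus_feed_invalid_date_type(self):
+        """
+        Test that an unsupported `date` type (e.g., int) raises ValueError.
+        """
+        with self.assertRaises(ValueError) as context:
+            # An int is rejected by get_bus_feed's type validation before
+            # any calendar lookup; only str, date, or list of str/date are accepted.
+            get_bus_feed(self.gtfs_path, date=20220110)
+
+        self.assertIn("Invalid date type", str(context.exception))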