import shlex


def parse_command_args_to_dict(arg_string: str) -> dict:
  """Parses a command-line argument string into a dictionary of parameters.

  The string is tokenized with `shlex.split`, so quoted values are kept
  together and their quotes are removed. Three long-option forms are
  recognized: `--flag`, `--key=value` and `--key value`.

  Tokens that do not start with `--` and are not consumed as the value of a
  preceding long option (stray positional words, short options such as `-f`)
  are ignored. A token following `--key` is treated as its value whenever it
  does not itself start with `--`, which includes short options and negative
  numbers.

  Args:
    arg_string: A string containing command-line arguments, such as
      "--master-ipv4-cidr=10.0.0.0/28 --enable-ip-alias". May be empty or None.

  Returns:
    A dictionary mapping each long option (including its leading `--`) to its
    string value, or to True for a bare flag. When an option is repeated the
    last occurrence wins.

  Raises:
    ValueError: Propagated from `shlex.split` when the string contains
      unbalanced quotes.
  """
  parsed_args = {}
  if not arg_string:
    return parsed_args

  tokens = shlex.split(arg_string)
  index = 0
  while index < len(tokens):
    token = tokens[index]
    index += 1
    if not token.startswith('--'):
      # Positional words and short options are intentionally skipped.
      continue
    if '=' in token:
      # Split on the first '=' only so values may themselves contain '='.
      key, value = token.split('=', 1)
      parsed_args[key] = value
    elif index < len(tokens) and not tokens[index].startswith('--'):
      # `--key value` form: consume the following token as the value.
      parsed_args[token] = tokens[index]
      index += 1
    else:
      parsed_args[token] = True
  return parsed_args


def _conflicting_flags(key: str) -> tuple:
  """Returns the flag names that contradict `key` and must be dropped."""
  if key.startswith('--no-'):
    # --no-<name> cancels --<name>.
    return (f'--{key[len("--no-"):]}',)
  if key.startswith('--enable-'):
    feature = key[len('--enable-'):]
    return (f'--no-enable-{feature}', f'--disable-{feature}')
  if key.startswith('--disable-'):
    return (f'--enable-{key[len("--disable-"):]}',)
  return ()


def process_gcloud_args(user_parsed_args, final_gcloud_args):
  """Applies user-supplied arguments on top of the final gcloud arguments.

  `final_gcloud_args` is modified in place. Every user argument is added or
  overwrites the existing entry, and any entry that contradicts it is removed:

    * `--no-<name>` removes `--<name>`.
    * `--enable-<name>` removes `--no-enable-<name>` and `--disable-<name>`.
    * `--disable-<name>` removes `--enable-<name>`.

  `--no-` and `--disable-` arguments are always stored as the boolean True,
  whatever value was parsed for them; all other arguments keep their value.

  Args:
    user_parsed_args: Dictionary of user arguments, as produced by
      `parse_command_args_to_dict`.
    final_gcloud_args: Dictionary of gcloud arguments to update in place.
  """
  for key, value in user_parsed_args.items():
    for opposite_key in _conflicting_flags(key):
      final_gcloud_args.pop(opposite_key, None)
    is_negation = key.startswith(('--no-', '--disable-'))
    final_gcloud_args[key] = True if is_negation else value


def merge_conditional_params(conditional_params, final_gcloud_args):
  """Merges conditional parameters into the final gcloud arguments in place.

  A parameter that is not yet present in `final_gcloud_args` is added. A
  parameter that is already present keeps its existing value, except for
  `--addons`, whose comma-separated lists are combined.

  The combined addon list keeps first-seen order (existing addons first) and
  drops duplicates and empty entries, so the result is identical on every run.

  Args:
    conditional_params: Dictionary of parameters to merge in.
    final_gcloud_args: Dictionary of gcloud arguments to update in place.
  """
  for key, value in conditional_params.items():
    if key not in final_gcloud_args:
      final_gcloud_args[key] = value
    elif key == '--addons':
      combined = final_gcloud_args[key].split(',') + value.split(',')
      # dict.fromkeys de-duplicates while preserving insertion order; a set
      # would make the order (and thus the command string) vary between runs.
      unique_addons = dict.fromkeys(
          name.strip() for name in combined if name.strip()
      )
      final_gcloud_args[key] = ','.join(unique_addons)


def construct_gcloud_command_string(
    cluster_name: str, gcloud_args: dict
) -> str:
  """Constructs the gcloud command string from a dictionary of arguments.

  The cluster name, every option name and every value are escaped with
  `shlex.quote`, so values containing spaces, quotes or shell metacharacters
  stay a single argument. Text made only of shell-safe characters is emitted
  unchanged.

  Args:
    cluster_name: The name of the cluster.
    gcloud_args: A dictionary where keys are gcloud argument names and values
      are their parsed values. True emits a bare flag; False, None and blank
      strings omit the argument; anything else is emitted as `key=value`.

  Returns:
    A complete gcloud command string.
  """
  # NOTE(review): quoting assumes the string is later run through a POSIX
  # shell (or split with shlex) - confirm against the command runner.
  command_parts = [
      'gcloud beta container clusters create',
      shlex.quote(cluster_name),
  ]

  for key, value in gcloud_args.items():
    if value is True:
      command_parts.append(shlex.quote(key))
      continue
    if value is False or value is None:
      continue
    text = str(value)
    if text.strip() == '':
      continue
    command_parts.append(f'{shlex.quote(key)}={shlex.quote(text)}')

  return ' '.join(command_parts)
+ final_gcloud_args = {} + final_gcloud_args['--project'] = args.project + final_gcloud_args['--region'] = zone_to_region(args.zone) + final_gcloud_args['--node-locations'] = args.zone + final_gcloud_args['--cluster-version'] = gke_control_plane_version + final_gcloud_args['--machine-type'] = machine_type + final_gcloud_args['--enable-autoscaling'] = True + final_gcloud_args['--total-min-nodes'] = 1 + final_gcloud_args['--total-max-nodes'] = 1000 + final_gcloud_args['--num-nodes'] = args.default_pool_cpu_num_nodes + final_gcloud_args['--enable-dns-access'] = True + # This value is from here: https://cloud.google.com/kubernetes-engine/docs/how-to/legacy/network-isolation + final_gcloud_args['--master-ipv4-cidr'] = '172.16.0.32/28' + # This value is from here https://cloud.google.com/vpc/docs/subnets + final_gcloud_args['--cluster-ipv4-cidr'] = '10.224.0.0/12' + final_gcloud_args['--enable-private-nodes'] = True + final_gcloud_args['--enable-ip-alias'] = True + final_gcloud_args['--autoscaling-profile'] = 'optimize-utilization' - # If the user passes in the gke version then we use that directly instead of the rapid release. - # This allows users to directly pass a specified gke version without release channel constraints. 
- rapid_release_cmd = '' if args.gke_version is not None: - rapid_release_cmd = ' --release-channel rapid' - - command = ( - 'gcloud beta container clusters create' - f' {args.cluster} --project={args.project}' - f' --region={zone_to_region(args.zone)}' - f' --node-locations={args.zone}' - f' --cluster-version={gke_control_plane_version}' - f' --machine-type={machine_type}' - ' --enable-autoscaling' - ' --total-min-nodes 1 --total-max-nodes 1000' - f' --num-nodes {args.default_pool_cpu_num_nodes}' - f' {args.custom_cluster_arguments}' - f' {rapid_release_cmd}' - ' --enable-dns-access' - ' --autoscaling-profile=optimize-utilization' - ) - - enable_ip_alias = False + final_gcloud_args['--release-channel'] = 'rapid' + conditional_params = {} if args.private or args.authorized_networks is not None: - enable_ip_alias = True - command += ' --enable-master-authorized-networks --enable-private-nodes' + conditional_params['--enable-master-authorized-networks'] = True + conditional_params['--enable-private-nodes'] = True + conditional_params['--enable-ip-alias'] = True if system.accelerator_type == AcceleratorType['GPU']: - enable_ip_alias = True - command += ( - ' --enable-dataplane-v2' - ' --enable-multi-networking --no-enable-autoupgrade' - ) + conditional_params['--enable-dataplane-v2'] = True + conditional_params['--enable-multi-networking'] = True + conditional_params['--no-enable-autoupgrade'] = True + conditional_params['--enable-ip-alias'] = True else: - command += ' --location-policy=BALANCED --scopes=storage-full,gke-default' - + conditional_params['--location-policy'] = 'BALANCED' + conditional_params['--scopes'] = 'storage-full,gke-default' if args.enable_pathways: - enable_ip_alias = True - - if enable_ip_alias: - command += ' --enable-ip-alias' + conditional_params['--enable-ip-alias'] = True if args.enable_ray_cluster: - command += ' --addons RayOperator' + conditional_params['--addons'] = 'RayOperator' if args.enable_workload_identity or 
args.enable_gcsfuse_csi_driver: - command += f' --workload-pool={args.project}.svc.id.goog' + conditional_params['--workload-pool'] = f'{args.project}.svc.id.goog' addons = [] if args.enable_gcsfuse_csi_driver: @@ -1146,14 +1260,18 @@ def run_gke_cluster_create_command( if args.enable_lustre_csi_driver: addons.append('LustreCsiDriver') - command += ' --enable-legacy-lustre-port' + conditional_params['--enable-legacy-lustre-port'] = True if hasattr(args, 'enable_mtc') and args.enable_mtc: addons.append('HighScaleCheckpointing') if len(addons) > 0: - addons_str = ','.join(addons) - command += f' --addons={addons_str}' + conditional_params['--addons'] = ','.join(addons) + + merge_conditional_params(conditional_params, final_gcloud_args) + user_parsed_args = parse_command_args_to_dict(args.custom_cluster_arguments) + process_gcloud_args(user_parsed_args, final_gcloud_args) + command = construct_gcloud_command_string(args.cluster, final_gcloud_args) return_code = run_command_with_updates(command, 'GKE Cluster Create', args) if return_code != 0: diff --git a/src/xpk/commands/tests/unit/test_arg_parser.py b/src/xpk/commands/tests/unit/test_arg_parser.py new file mode 100644 index 000000000..a75b7935a --- /dev/null +++ b/src/xpk/commands/tests/unit/test_arg_parser.py @@ -0,0 +1,58 @@ +"""Unit tests for the arg_parser module in xpk.commands.""" + +import unittest +from src.xpk.commands.cluster import parse_command_args_to_dict + + +class TestParseCommandArgsToDict(unittest.TestCase): + """Tests the parse_command_args_to_dict function from the cluster module.""" + + def test_empty_string(self): + self.assertEqual(parse_command_args_to_dict(''), {}) + + def test_simple_key_value_pairs(self): + result = parse_command_args_to_dict('--key1=value1 --key2=value2') + self.assertEqual(result, {'--key1': 'value1', '--key2': 'value2'}) + + def test_flag_with_space_value(self): + result = parse_command_args_to_dict('--key1 value1 --key2 value2') + self.assertEqual(result, {'--key1': 
'value1', '--key2': 'value2'}) + + def test_boolean_flags(self): + result = parse_command_args_to_dict('--enable-feature --no-logs') + self.assertEqual(result, {'--enable-feature': True, '--no-logs': True}) + + def test_mixed_formats(self): + result = parse_command_args_to_dict( + '--project=my-project --zone us-central1 --dry-run' + ) + self.assertEqual( + result, + {'--project': 'my-project', '--zone': 'us-central1', '--dry-run': True}, + ) + + def test_quoted_values(self): + result = parse_command_args_to_dict( + '--description "My cluster with spaces" --name=test-cluster' + ) + self.assertEqual( + result, + {'--description': 'My cluster with spaces', '--name': 'test-cluster'}, + ) + + def test_no_double_hyphen_flags(self): + result = parse_command_args_to_dict('random-word -f --flag') + self.assertEqual(result, {'--flag': True}) # Only --flag should be parsed + + def test_duplicate_keys_last_one_wins(self): + result = parse_command_args_to_dict('--key=value1 --key=value2') + self.assertEqual(result, {'--key': 'value2'}) + + def test_hyphenated_keys(self): + result = parse_command_args_to_dict('--api-endpoint=some-url') + self.assertEqual(result, {'--api-endpoint': 'some-url'}) + + +if __name__ == '__main__': + # Run python3 -m src.xpk.commands.tests.unit.test_arg_parser under the xpk folder. 
+ unittest.main() diff --git a/src/xpk/commands/tests/unit/test_gcloud_arg_processor.py b/src/xpk/commands/tests/unit/test_gcloud_arg_processor.py new file mode 100644 index 000000000..efaa8ac3f --- /dev/null +++ b/src/xpk/commands/tests/unit/test_gcloud_arg_processor.py @@ -0,0 +1,118 @@ +"""Unit tests for the gcloud_arg_parser module in xpk.commands.""" + +import unittest +from src.xpk.commands.cluster import process_gcloud_args + + +class TestProcessGcloudArgs(unittest.TestCase): + """Tests the process_gcloud_args function from the cluster module.""" + + def test_add_new_argument(self): + final_args = {'--existing-key': 'existing-value'} + user_args = {'--new-key': 'new-value'} + process_gcloud_args(user_args, final_args) + self.assertEqual( + final_args, + {'--existing-key': 'existing-value', '--new-key': 'new-value'}, + ) + + def test_override_existing_argument(self): + final_args = {'--common-key': 'old-value'} + user_args = {'--common-key': 'new-value'} + process_gcloud_args(user_args, final_args) + self.assertEqual(final_args, {'--common-key': 'new-value'}) + + def test_no_enable_flag_overrides_enable(self): + final_args = {'--enable-logging': True} + user_args = {'--no-enable-logging': True} + process_gcloud_args(user_args, final_args) + self.assertEqual(final_args, {'--no-enable-logging': True}) + self.assertNotIn('--enable-logging', final_args) + + def test_enable_flag_overrides_no_enable(self): + final_args = {'--no-enable-monitoring': True} + user_args = {'--enable-monitoring': True} + process_gcloud_args(user_args, final_args) + self.assertEqual(final_args, {'--enable-monitoring': True}) + self.assertNotIn('--no-enable-monitoring', final_args) + + def test_no_conflict(self): + final_args = {'--param1': 'value1'} + user_args = {'--param2': 'value2'} + process_gcloud_args(user_args, final_args) + self.assertEqual(final_args, {'--param1': 'value1', '--param2': 'value2'}) + + def test_empty_user_args(self): + final_args = {'--param1': 'value1'} + 
user_args = {} + process_gcloud_args(user_args, final_args) + self.assertEqual(final_args, {'--param1': 'value1'}) + + def test_complex_overrides(self): + final_args = { + '--zone': 'us-east1-b', + '--enable-ip-alias': True, + '--machine-type': 'n1-standard-4', + '--no-enable-public-ip': ( + True # This will be removed if --enable-public-ip is set + ), + } + user_args = { + '--zone': 'us-central1-a', # Overrides + '--no-enable-ip-alias': True, # Overrides --enable-ip-alias + '--disk-size': '200GB', # New + '--enable-public-ip': True, # Overrides --no-enable-public-ip + } + process_gcloud_args(user_args, final_args) + self.assertEqual( + final_args, + { + '--zone': 'us-central1-a', + '--no-enable-ip-alias': True, + '--machine-type': 'n1-standard-4', # Not affected + '--disk-size': '200GB', + '--enable-public-ip': True, + }, + ) + self.assertNotIn('--enable-ip-alias', final_args) + self.assertNotIn('--no-enable-public-ip', final_args) + + def test_disable_flag_is_added(self): + """ + Tests that a --disable- flag from user_args is simply added to final_args + when no conflicting --enable- or --no-enable- flag exists. + """ + final_args = {'--existing-flag': 'value'} + user_args = {'--disable-dataplane-v2': True} + process_gcloud_args(user_args, final_args) + self.assertEqual( + final_args, {'--existing-flag': 'value', '--disable-dataplane-v2': True} + ) + + def test_enable_flag_overrides_disable(self): + """ + Tests that an --enable- flag from user_args overrides a --disable- flag + present in final_args. 
+ """ + final_args = {'--disable-logging': True} # Existing disable flag + user_args = {'--enable-logging': True} # User wants to enable + process_gcloud_args(user_args, final_args) + self.assertEqual(final_args, {'--enable-logging': True}) + self.assertNotIn('--disable-logging', final_args) + + def test_disable_flag_overrides_enable_from_user_args_order(self): + """ + Tests that if --enable- is in final_args and --disable- is in user_args, + the --disable- from user_args takes precedence and removes the --enable-. + This is implied by the order of processing (user_args overwrite/remove final_args). + """ + final_args = {'--enable-some-feature': True} + user_args = {'--disable-some-feature': True} + process_gcloud_args(user_args, final_args) + self.assertEqual(final_args, {'--disable-some-feature': True}) + self.assertNotIn('--enable-some-feature', final_args) + + +if __name__ == '__main__': + # Run python3 -m src.xpk.commands.tests.unit.test_gcloud_arg_processor under the xpk folder. + unittest.main()