Skip to content
175 changes: 134 additions & 41 deletions src/xpk/commands/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@
from ..utils.file import write_tmp_file
from . import cluster_gcluster
from .common import set_cluster_command
import shlex


def cluster_adapt(args) -> None:
Expand Down Expand Up @@ -794,6 +795,85 @@ def run_gke_clusters_list_command(args) -> int:
return 0


def parse_command_args_to_dict(arg_string: str) -> dict:
"""Parses a command-line argument string into a dictionary of parameters.

This function safely splits a command-line string, handling quoted arguments
and different parameter formats (e.g., --flag, --key=value, --key value).
It's designed to help convert user-provided custom arguments into a structured
format for easier merging and de-duplication.

Args:
arg_string: A string containing command-line arguments, such as
"--master-ipv4-cidr=10.0.0.0/28 --enable-ip-alias".

Returns:
A dictionary where keys are parameter names (e.g., "--enable-ip-alias",
"--cluster-ipv4-cidr") and values are their corresponding parsed values
(e.g., True for a boolean flag, "10.0.0.0/28" for a string value).
"""
parsed_args = {}
if not arg_string:
return parsed_args

tokens = shlex.split(arg_string)
# After shlex.split: Print the tokens list
xpk_print(f'Shlex-split tokens: {tokens}')
i = 0
while i < len(tokens):
token = tokens[i]
if token.startswith('--'):
if '=' in token:
key, value = token.split('=', 1)
parsed_args[key] = value
else:
if i + 1 < len(tokens) and not tokens[i + 1].startswith('--'):
parsed_args[token] = tokens[i + 1]
i += 1
else:
parsed_args[token] = True
elif token.startswith('-'):
pass
i += 1
# After parsing: Print the final parsed dictionary
xpk_print(f'Final parsed_args: {parsed_args}')
xpk_print('-------------------------------------------')
return parsed_args


def process_gcloud_args(user_parsed_args, final_gcloud_args):
"""
Processes custom cluster arguments and updates the final gcloud arguments dictionary.

This function handles special cases for '--no-' and '--enable-' prefixes
in custom arguments to correctly modify the gcloud arguments.

"""
for key, value in user_parsed_args.items():
if key.startswith('--no-'):
opposite_key = f'--{key[5:]}'
if opposite_key in final_gcloud_args:
del final_gcloud_args[opposite_key]
final_gcloud_args[key] = True
elif key.startswith('--enable-'):
opposite_key = f'--no-{key[2:]}'
opposite_disable_key = f'--disable-{key[9:]}'
if opposite_key in final_gcloud_args:
del final_gcloud_args[opposite_key]
if opposite_disable_key in final_gcloud_args:
del final_gcloud_args[opposite_disable_key]
final_gcloud_args[key] = value
elif key.startswith('--disable-'):
feature_name = key[10:]
opposite_enable_key = f'--enable-{feature_name}'
if opposite_enable_key in final_gcloud_args:
del final_gcloud_args[opposite_enable_key]
final_gcloud_args[key] = True
else:
# For all other arguments, simply add or update their values.
final_gcloud_args[key] = value


def run_gke_cluster_create_command(
args, gke_control_plane_version: str, system: SystemCharacteristics
) -> int:
Expand All @@ -817,58 +897,48 @@ def run_gke_cluster_create_command(
)
machine_type = args.cluster_cpu_machine_type

# Create the regional cluster with `num-nodes` CPU nodes in the same zone as
# TPUs. This has been tested with clusters of 300 VMs. Larger clusters will
# benefit from a larger initial `--num-nodes`. After the cluster is created,
# the auto-scaler can reduce/increase the nodes based on the load.
final_gcloud_args = {}
final_gcloud_args['--project'] = args.project
final_gcloud_args['--region'] = zone_to_region(args.zone)
final_gcloud_args['--node-locations'] = args.zone
final_gcloud_args['--cluster-version'] = gke_control_plane_version
final_gcloud_args['--machine-type'] = machine_type
final_gcloud_args['--enable-autoscaling'] = True
final_gcloud_args['--total-min-nodes'] = 1
final_gcloud_args['--total-max-nodes'] = 1000
final_gcloud_args['--num-nodes'] = args.default_pool_cpu_num_nodes
final_gcloud_args['--enable-dns-access'] = True
# This value is from here: https://cloud.google.com/kubernetes-engine/docs/how-to/legacy/network-isolation
final_gcloud_args['--master-ipv4-cidr'] = '172.16.0.32/28'
# This value is from here https://cloud.google.com/vpc/docs/subnets
final_gcloud_args['--cluster-ipv4-cidr'] = '10.224.0.0/12'

# If the user passes in the gke version then we use that directly instead of the rapid release.
# This allows users to directly pass a specified gke version without release channel constraints.
rapid_release_cmd = ''
if args.gke_version is not None:
rapid_release_cmd = ' --release-channel rapid'

command = (
'gcloud beta container clusters create'
f' {args.cluster} --project={args.project}'
f' --region={zone_to_region(args.zone)}'
f' --node-locations={args.zone}'
f' --cluster-version={gke_control_plane_version}'
f' --machine-type={machine_type}'
' --enable-autoscaling'
' --total-min-nodes 1 --total-max-nodes 1000'
f' --num-nodes {args.default_pool_cpu_num_nodes}'
f' {args.custom_cluster_arguments}'
f' {rapid_release_cmd}'
' --enable-dns-access'
)
final_gcloud_args['--release-channel'] = 'rapid'

enable_ip_alias = False
conditional_params = {}

if args.private or args.authorized_networks is not None:
enable_ip_alias = True
command += ' --enable-master-authorized-networks --enable-private-nodes'
conditional_params['--enable-master-authorized-networks'] = True
conditional_params['--enable-private-nodes'] = True
conditional_params['--enable-ip-alias'] = True

if system.accelerator_type == AcceleratorType['GPU']:
enable_ip_alias = True
command += (
' --enable-dataplane-v2'
' --enable-multi-networking --no-enable-autoupgrade'
)
conditional_params['--enable-dataplane-v2'] = True
conditional_params['--enable-multi-networking'] = True
conditional_params['--no-enable-autoupgrade'] = True
conditional_params['--enable-ip-alias'] = True
else:
command += ' --location-policy=BALANCED --scopes=storage-full,gke-default'

conditional_params['--location-policy'] = 'BALANCED'
conditional_params['--scopes'] = 'storage-full,gke-default'
if args.enable_pathways:
enable_ip_alias = True

if enable_ip_alias:
command += ' --enable-ip-alias'
conditional_params['--enable-ip-alias'] = True

if args.enable_ray_cluster:
command += ' --addons RayOperator'
conditional_params['--addons'] = 'RayOperator'

if args.enable_workload_identity or args.enable_gcsfuse_csi_driver:
command += f' --workload-pool={args.project}.svc.id.goog'
conditional_params['--workload-pool'] = f'{args.project}.svc.id.goog'

addons = []
if args.enable_gcsfuse_csi_driver:
Expand All @@ -887,8 +957,31 @@ def run_gke_cluster_create_command(
addons.append('HighScaleCheckpointing')

if len(addons) > 0:
addons_str = ','.join(addons)
command += f' --addons={addons_str}'
conditional_params['--addons'] = ','.join(addons)

for key, value in conditional_params.items():
if key not in final_gcloud_args:
final_gcloud_args[key] = value
elif key == '--addons' and key in final_gcloud_args:
final_gcloud_args[key] = ','.join(
list(set(final_gcloud_args[key].split(',') + value.split(',')))
)

user_parsed_args = parse_command_args_to_dict(args.custom_cluster_arguments)
process_gcloud_args(user_parsed_args, final_gcloud_args)

command_parts = ['gcloud beta container clusters create', args.cluster]
for key, value in final_gcloud_args.items():
if value is True:
command_parts.append(key)
elif value is False:
pass
elif value is not None and value != ' ':
if ' ' in str(value):
command_parts.append(f'{key}="{value}"')
else:
command_parts.append(f'{key}={value}')
command = ' '.join(command_parts)

return_code = run_command_with_updates(command, 'GKE Cluster Create', args)
if return_code != 0:
Expand Down
58 changes: 58 additions & 0 deletions src/xpk/commands/tests/unit/test_arg_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
"""Unit tests for the arg_parser module in xpk.commands."""

import unittest
from src.xpk.commands.cluster import parse_command_args_to_dict


class TestParseCommandArgsToDict(unittest.TestCase):
"""Tests the parse_command_args_to_dict function from the cluster module."""

def test_empty_string(self):
self.assertEqual(parse_command_args_to_dict(''), {})

def test_simple_key_value_pairs(self):
result = parse_command_args_to_dict('--key1=value1 --key2=value2')
self.assertEqual(result, {'--key1': 'value1', '--key2': 'value2'})

def test_flag_with_space_value(self):
result = parse_command_args_to_dict('--key1 value1 --key2 value2')
self.assertEqual(result, {'--key1': 'value1', '--key2': 'value2'})

def test_boolean_flags(self):
result = parse_command_args_to_dict('--enable-feature --no-logs')
self.assertEqual(result, {'--enable-feature': True, '--no-logs': True})

def test_mixed_formats(self):
result = parse_command_args_to_dict(
'--project=my-project --zone us-central1 --dry-run'
)
self.assertEqual(
result,
{'--project': 'my-project', '--zone': 'us-central1', '--dry-run': True},
)

def test_quoted_values(self):
result = parse_command_args_to_dict(
'--description "My cluster with spaces" --name=test-cluster'
)
self.assertEqual(
result,
{'--description': 'My cluster with spaces', '--name': 'test-cluster'},
)

def test_no_double_hyphen_flags(self):
result = parse_command_args_to_dict('random-word -f --flag')
self.assertEqual(result, {'--flag': True}) # Only --flag should be parsed

def test_duplicate_keys_last_one_wins(self):
result = parse_command_args_to_dict('--key=value1 --key=value2')
self.assertEqual(result, {'--key': 'value2'})

def test_hyphenated_keys(self):
result = parse_command_args_to_dict('--api-endpoint=some-url')
self.assertEqual(result, {'--api-endpoint': 'some-url'})


if __name__ == '__main__':
# Run python3 -m src.xpk.commands.tests.unit.test_arg_parser under the xpk folder.
unittest.main()
118 changes: 118 additions & 0 deletions src/xpk/commands/tests/unit/test_gcloud_arg_processor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
"""Unit tests for the gcloud_arg_parser module in xpk.commands."""

import unittest
from src.xpk.commands.cluster import process_gcloud_args


class TestProcessGcloudArgs(unittest.TestCase):
"""Tests the process_gcloud_args function from the cluster module."""

def test_add_new_argument(self):
final_args = {'--existing-key': 'existing-value'}
user_args = {'--new-key': 'new-value'}
process_gcloud_args(user_args, final_args)
self.assertEqual(
final_args,
{'--existing-key': 'existing-value', '--new-key': 'new-value'},
)

def test_override_existing_argument(self):
final_args = {'--common-key': 'old-value'}
user_args = {'--common-key': 'new-value'}
process_gcloud_args(user_args, final_args)
self.assertEqual(final_args, {'--common-key': 'new-value'})

def test_no_enable_flag_overrides_enable(self):
final_args = {'--enable-logging': True}
user_args = {'--no-enable-logging': True}
process_gcloud_args(user_args, final_args)
self.assertEqual(final_args, {'--no-enable-logging': True})
self.assertNotIn('--enable-logging', final_args)

def test_enable_flag_overrides_no_enable(self):
final_args = {'--no-enable-monitoring': True}
user_args = {'--enable-monitoring': True}
process_gcloud_args(user_args, final_args)
self.assertEqual(final_args, {'--enable-monitoring': True})
self.assertNotIn('--no-enable-monitoring', final_args)

def test_no_conflict(self):
final_args = {'--param1': 'value1'}
user_args = {'--param2': 'value2'}
process_gcloud_args(user_args, final_args)
self.assertEqual(final_args, {'--param1': 'value1', '--param2': 'value2'})

def test_empty_user_args(self):
final_args = {'--param1': 'value1'}
user_args = {}
process_gcloud_args(user_args, final_args)
self.assertEqual(final_args, {'--param1': 'value1'})

def test_complex_overrides(self):
final_args = {
'--zone': 'us-east1-b',
'--enable-ip-alias': True,
'--machine-type': 'n1-standard-4',
'--no-enable-public-ip': (
True # This will be removed if --enable-public-ip is set
),
}
user_args = {
'--zone': 'us-central1-a', # Overrides
'--no-enable-ip-alias': True, # Overrides --enable-ip-alias
'--disk-size': '200GB', # New
'--enable-public-ip': True, # Overrides --no-enable-public-ip
}
process_gcloud_args(user_args, final_args)
self.assertEqual(
final_args,
{
'--zone': 'us-central1-a',
'--no-enable-ip-alias': True,
'--machine-type': 'n1-standard-4', # Not affected
'--disk-size': '200GB',
'--enable-public-ip': True,
},
)
self.assertNotIn('--enable-ip-alias', final_args)
self.assertNotIn('--no-enable-public-ip', final_args)

def test_disable_flag_is_added(self):
"""
Tests that a --disable- flag from user_args is simply added to final_args
when no conflicting --enable- or --no-enable- flag exists.
"""
final_args = {'--existing-flag': 'value'}
user_args = {'--disable-dataplane-v2': True}
process_gcloud_args(user_args, final_args)
self.assertEqual(
final_args, {'--existing-flag': 'value', '--disable-dataplane-v2': True}
)

def test_enable_flag_overrides_disable(self):
"""
Tests that an --enable- flag from user_args overrides a --disable- flag
present in final_args.
"""
final_args = {'--disable-logging': True} # Existing disable flag
user_args = {'--enable-logging': True} # User wants to enable
process_gcloud_args(user_args, final_args)
self.assertEqual(final_args, {'--enable-logging': True})
self.assertNotIn('--disable-logging', final_args)

def test_disable_flag_overrides_enable_from_user_args_order(self):
"""
Tests that if --enable- is in final_args and --disable- is in user_args,
the --disable- from user_args takes precedence and removes the --enable-.
This is implied by the order of processing (user_args overwrite/remove final_args).
"""
final_args = {'--enable-some-feature': True}
user_args = {'--disable-some-feature': True}
process_gcloud_args(user_args, final_args)
self.assertEqual(final_args, {'--disable-some-feature': True})
self.assertNotIn('--enable-some-feature', final_args)


if __name__ == '__main__':
# Run python3 -m src.xpk.commands.tests.unit.test_gcloud_arg_processor under the xpk folder.
unittest.main()
Loading