Skip to content

Commit b350ac4

Browse files
KevinBNaughton authored and Google Cloud Pipeline Components maintainers committed
feat(components): Upgrade LLM evaluation classification and text generation pipelines to preview
PiperOrigin-RevId: 555540517
1 parent 2f32b23 commit b350ac4

File tree

2 files changed

+27
-29
lines changed

2 files changed

+27
-29
lines changed
Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,16 @@
1515

1616

1717
@dsl.pipeline(name=_PIPELINE_NAME)
18-
def llm_eval_classification_pipeline( # pylint: disable=dangerous-default-value
18+
def evaluation_llm_classification_pipeline( # pylint: disable=dangerous-default-value
1919
project: str,
2020
location: str,
2121
target_field_name: str,
22+
batch_predict_gcs_source_uris: List[str],
2223
batch_predict_gcs_destination_output_uri: str,
2324
model_name: str = 'publishers/google/models/text-bison@001',
2425
evaluation_task: str = 'text-classification',
2526
evaluation_class_labels: List[str] = [],
2627
batch_predict_instances_format: str = 'jsonl',
27-
batch_predict_gcs_source_uris: List[str] = [],
2828
batch_predict_predictions_format: str = 'jsonl',
2929
machine_type: str = 'e2-highmem-16',
3030
service_account: str = '',
@@ -49,6 +49,13 @@ def llm_eval_classification_pipeline( # pylint: disable=dangerous-default-value
4949
target_field_name: The target field's name. Formatted to be able to find
5050
nested columns, delimited by ``.``. Prefixed with 'instance.' on the
5151
component for Vertex Batch Prediction.
52+
batch_predict_gcs_source_uris: Google Cloud Storage URI(-s) to your
53+
instances data to run batch prediction on. The instances data should also
54+
contain the ground truth (target) data, used for evaluation. May contain
55+
wildcards. For more information on wildcards, see
56+
https://cloud.google.com/storage/docs/gsutil/addlhelp/WildcardNames. For
57+
more details about this input config, see
58+
https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.batchPredictionJobs#InputConfig.
5259
batch_predict_gcs_destination_output_uri: The Google Cloud Storage location
5360
of the directory where the output is to be written to.
5461
model_name: The Model name used to run evaluation. Must be a publisher Model
@@ -65,13 +72,6 @@ def llm_eval_classification_pipeline( # pylint: disable=dangerous-default-value
6572
must be one of the Model's supportedInputStorageFormats. For more details
6673
about this input config, see
6774
https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.batchPredictionJobs#InputConfig.
68-
batch_predict_gcs_source_uris: Google Cloud Storage URI(-s) to your
69-
instances data to run batch prediction on. The instances data should also
70-
contain the ground truth (target) data, used for evaluation. May contain
71-
wildcards. For more information on wildcards, see
72-
https://cloud.google.com/storage/docs/gsutil/addlhelp/WildcardNames. For
73-
more details about this input config, see
74-
https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.batchPredictionJobs#InputConfig.
7575
batch_predict_predictions_format: The format in which Vertex AI gives the
7676
predictions. Must be one of the Model's supportedOutputStorageFormats. For
7777
more details about this output config, see
@@ -113,19 +113,18 @@ def llm_eval_classification_pipeline( # pylint: disable=dangerous-default-value
113113
created.
114114
115115
Returns:
116-
NamedTuple:
117-
evaluation_metrics: ClassificationMetrics Artifact for LLM Text
118-
Classification.
119-
evaluation_resource_name: If run on a user's managed VertexModel, the
120-
imported evaluation resource name. Empty if run on a publisher model.
116+
evaluation_metrics: ClassificationMetrics Artifact for LLM Text
117+
Classification.
118+
evaluation_resource_name: If run on a user's managed VertexModel, the
119+
imported evaluation resource name. Empty if run on a publisher model.
121120
"""
122121
outputs = NamedTuple(
123122
'outputs',
124123
evaluation_metrics=ClassificationMetrics,
125124
evaluation_resource_name=str,
126125
)
127126

128-
get_vertex_model_task = dsl.importer_node.importer(
127+
get_vertex_model_task = dsl.importer(
129128
artifact_uri=(
130129
f'https://{location}-aiplatform.googleapis.com/v1/{model_name}'
131130
),
Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,14 @@
1414

1515

1616
@dsl.pipeline(name=_PIPELINE_NAME)
17-
def llm_eval_text_generation_pipeline( # pylint: disable=dangerous-default-value
17+
def evaluation_llm_text_generation_pipeline( # pylint: disable=dangerous-default-value
1818
project: str,
1919
location: str,
20+
batch_predict_gcs_source_uris: List[str],
2021
batch_predict_gcs_destination_output_uri: str,
2122
model_name: str = 'publishers/google/models/text-bison@001',
2223
evaluation_task: str = 'text-generation',
2324
batch_predict_instances_format: str = 'jsonl',
24-
batch_predict_gcs_source_uris: List[str] = [],
2525
batch_predict_predictions_format: str = 'jsonl',
2626
machine_type: str = 'e2-highmem-16',
2727
service_account: str = '',
@@ -39,6 +39,13 @@ def llm_eval_text_generation_pipeline( # pylint: disable=dangerous-default-valu
3939
Args:
4040
project: The GCP project that runs the pipeline components.
4141
location: The GCP region that runs the pipeline components.
42+
batch_predict_gcs_source_uris: Google Cloud Storage URI(-s) to your
43+
instances data to run batch prediction on. The instances data should also
44+
contain the ground truth (target) data, used for evaluation. May contain
45+
wildcards. For more information on wildcards, see
46+
https://cloud.google.com/storage/docs/gsutil/addlhelp/WildcardNames. For
47+
more details about this input config, see
48+
https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.batchPredictionJobs#InputConfig.
4249
batch_predict_gcs_destination_output_uri: The Google Cloud Storage location
4350
of the directory where the output is to be written to.
4451
model_name: The Model name used to run evaluation. Must be a publisher Model
@@ -53,13 +60,6 @@ def llm_eval_text_generation_pipeline( # pylint: disable=dangerous-default-valu
5360
must be one of the Model's supportedInputStorageFormats. Only "jsonl" is
5461
currently supported. For more details about this input config, see
5562
https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.batchPredictionJobs#InputConfig.
56-
batch_predict_gcs_source_uris: Google Cloud Storage URI(-s) to your
57-
instances data to run batch prediction on. The instances data should also
58-
contain the ground truth (target) data, used for evaluation. May contain
59-
wildcards. For more information on wildcards, see
60-
https://cloud.google.com/storage/docs/gsutil/addlhelp/WildcardNames. For
61-
more details about this input config, see
62-
https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.batchPredictionJobs#InputConfig.
6363
batch_predict_predictions_format: The format in which Vertex AI gives the
6464
predictions. Must be one of the Model's supportedOutputStorageFormats.
6565
Only "jsonl" is currently supported. For more details about this output
@@ -91,18 +91,17 @@ def llm_eval_text_generation_pipeline( # pylint: disable=dangerous-default-valu
9191
created.
9292
9393
Returns:
94-
NamedTuple:
95-
evaluation_metrics: Metrics Artifact for LLM Text Generation.
96-
evaluation_resource_name: If run on a user's managed VertexModel, the
97-
imported evaluation resource name. Empty if run on a publisher model.
94+
evaluation_metrics: Metrics Artifact for LLM Text Generation.
95+
evaluation_resource_name: If run on a user's managed VertexModel, the
96+
imported evaluation resource name. Empty if run on a publisher model.
9897
"""
9998
outputs = NamedTuple(
10099
'outputs',
101100
evaluation_metrics=Metrics,
102101
evaluation_resource_name=str,
103102
)
104103

105-
get_vertex_model_task = dsl.importer_node.importer(
104+
get_vertex_model_task = dsl.importer(
106105
artifact_uri=(
107106
f'https://{location}-aiplatform.googleapis.com/v1/{model_name}'
108107
),

0 commit comments

Comments
 (0)