Skip to content

Commit a90ee8d

Browse files
jaycee-li authored and copybara-github committed
feat: GenAI - Batch Prediction - Added support for tuned GenAI models
PiperOrigin-RevId: 646136098
1 parent a31ac4d commit a90ee8d

File tree

2 files changed

+163
-17
lines changed

2 files changed

+163
-17
lines changed

tests/unit/vertexai/test_batch_prediction.py

Lines changed: 124 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,15 @@
2525
import vertexai
2626
from google.cloud.aiplatform import base as aiplatform_base
2727
from google.cloud.aiplatform import initializer as aiplatform_initializer
28-
from google.cloud.aiplatform.compat.services import job_service_client
28+
from google.cloud.aiplatform.compat.services import (
29+
job_service_client,
30+
model_service_client,
31+
)
2932
from google.cloud.aiplatform.compat.types import (
3033
batch_prediction_job as gca_batch_prediction_job_compat,
3134
io as gca_io_compat,
3235
job_state as gca_job_state_compat,
36+
model as gca_model,
3337
)
3438
from vertexai.preview import batch_prediction
3539
from vertexai.generative_models import GenerativeModel
@@ -43,6 +47,7 @@
4347

4448
_TEST_GEMINI_MODEL_NAME = "gemini-1.0-pro"
4549
_TEST_GEMINI_MODEL_RESOURCE_NAME = f"publishers/google/models/{_TEST_GEMINI_MODEL_NAME}"
50+
_TEST_TUNED_GEMINI_MODEL_RESOURCE_NAME = "projects/123/locations/us-central1/models/456"
4651
_TEST_PALM_MODEL_NAME = "text-bison"
4752
_TEST_PALM_MODEL_RESOURCE_NAME = f"publishers/google/models/{_TEST_PALM_MODEL_NAME}"
4853

@@ -122,6 +127,48 @@ def get_batch_prediction_job_with_gcs_output_mock():
122127
yield get_job_mock
123128

124129

130+
@pytest.fixture
131+
def get_batch_prediction_job_with_tuned_gemini_model_mock():
132+
with mock.patch.object(
133+
job_service_client.JobServiceClient, "get_batch_prediction_job"
134+
) as get_job_mock:
135+
get_job_mock.return_value = gca_batch_prediction_job_compat.BatchPredictionJob(
136+
name=_TEST_BATCH_PREDICTION_JOB_NAME,
137+
display_name=_TEST_DISPLAY_NAME,
138+
model=_TEST_TUNED_GEMINI_MODEL_RESOURCE_NAME,
139+
state=_TEST_JOB_STATE_SUCCESS,
140+
output_info=gca_batch_prediction_job_compat.BatchPredictionJob.OutputInfo(
141+
gcs_output_directory=_TEST_GCS_OUTPUT_PREFIX
142+
),
143+
)
144+
yield get_job_mock
145+
146+
147+
@pytest.fixture
148+
def get_gemini_model_mock():
149+
with mock.patch.object(
150+
model_service_client.ModelServiceClient, "get_model"
151+
) as get_model_mock:
152+
get_model_mock.return_value = gca_model.Model(
153+
name=_TEST_TUNED_GEMINI_MODEL_RESOURCE_NAME,
154+
model_source_info=gca_model.ModelSourceInfo(
155+
source_type=gca_model.ModelSourceInfo.ModelSourceType.GENIE
156+
),
157+
)
158+
yield get_model_mock
159+
160+
161+
@pytest.fixture
162+
def get_non_gemini_model_mock():
163+
with mock.patch.object(
164+
model_service_client.ModelServiceClient, "get_model"
165+
) as get_model_mock:
166+
get_model_mock.return_value = gca_model.Model(
167+
name=_TEST_TUNED_GEMINI_MODEL_RESOURCE_NAME,
168+
)
169+
yield get_model_mock
170+
171+
125172
@pytest.fixture
126173
def get_batch_prediction_job_invalid_model_mock():
127174
with mock.patch.object(
@@ -205,6 +252,21 @@ def test_init_batch_prediction_job(
205252
name=_TEST_BATCH_PREDICTION_JOB_NAME, retry=aiplatform_base._DEFAULT_RETRY
206253
)
207254

255+
def test_init_batch_prediction_job_with_tuned_gemini_model(
256+
self,
257+
get_batch_prediction_job_with_tuned_gemini_model_mock,
258+
get_gemini_model_mock,
259+
):
260+
batch_prediction.BatchPredictionJob(_TEST_BATCH_PREDICTION_JOB_ID)
261+
262+
get_batch_prediction_job_with_tuned_gemini_model_mock.assert_called_once_with(
263+
name=_TEST_BATCH_PREDICTION_JOB_NAME, retry=aiplatform_base._DEFAULT_RETRY
264+
)
265+
get_gemini_model_mock.assert_called_once_with(
266+
name=_TEST_TUNED_GEMINI_MODEL_RESOURCE_NAME,
267+
retry=aiplatform_base._DEFAULT_RETRY,
268+
)
269+
208270
@pytest.mark.usefixtures("get_batch_prediction_job_invalid_model_mock")
209271
def test_init_batch_prediction_job_invalid_model(self):
210272
with pytest.raises(
@@ -217,6 +279,23 @@ def test_init_batch_prediction_job_invalid_model(self):
217279
):
218280
batch_prediction.BatchPredictionJob(_TEST_BATCH_PREDICTION_JOB_ID)
219281

282+
@pytest.mark.usefixtures(
283+
"get_batch_prediction_job_with_tuned_gemini_model_mock",
284+
"get_non_gemini_model_mock",
285+
)
286+
def test_init_batch_prediction_job_with_invalid_tuned_model(
287+
self,
288+
):
289+
with pytest.raises(
290+
ValueError,
291+
match=(
292+
f"BatchPredictionJob '{_TEST_BATCH_PREDICTION_JOB_ID}' "
293+
f"runs with the model '{_TEST_TUNED_GEMINI_MODEL_RESOURCE_NAME}', "
294+
"which is not a GenAI model."
295+
),
296+
):
297+
batch_prediction.BatchPredictionJob(_TEST_BATCH_PREDICTION_JOB_ID)
298+
220299
@pytest.mark.usefixtures("get_batch_prediction_job_with_gcs_output_mock")
221300
def test_submit_batch_prediction_job_with_gcs_input(
222301
self, create_batch_prediction_job_mock
@@ -368,16 +447,59 @@ def test_submit_batch_prediction_job_with_bq_input_without_output_uri_prefix(
368447
timeout=None,
369448
)
370449

450+
@pytest.mark.usefixtures("create_batch_prediction_job_mock")
451+
def test_submit_batch_prediction_job_with_tuned_model(
452+
self,
453+
get_gemini_model_mock,
454+
):
455+
job = batch_prediction.BatchPredictionJob.submit(
456+
source_model=_TEST_TUNED_GEMINI_MODEL_RESOURCE_NAME,
457+
input_dataset=_TEST_BQ_INPUT_URI,
458+
)
459+
460+
assert job.gca_resource == _TEST_GAPIC_BATCH_PREDICTION_JOB
461+
get_gemini_model_mock.assert_called_once_with(
462+
name=_TEST_TUNED_GEMINI_MODEL_RESOURCE_NAME,
463+
retry=aiplatform_base._DEFAULT_RETRY,
464+
)
465+
371466
def test_submit_batch_prediction_job_with_invalid_source_model(self):
372467
with pytest.raises(
373468
ValueError,
374-
match=(f"Model '{_TEST_PALM_MODEL_RESOURCE_NAME}' is not a GenAI model."),
469+
match=(
470+
f"Model '{_TEST_PALM_MODEL_RESOURCE_NAME}' is not a Generative AI model."
471+
),
375472
):
376473
batch_prediction.BatchPredictionJob.submit(
377474
source_model=_TEST_PALM_MODEL_NAME,
378475
input_dataset=_TEST_GCS_INPUT_URI,
379476
)
380477

478+
@pytest.mark.usefixtures("get_non_gemini_model_mock")
479+
def test_submit_batch_prediction_job_with_non_gemini_tuned_model(self):
480+
with pytest.raises(
481+
ValueError,
482+
match=(
483+
f"Model '{_TEST_TUNED_GEMINI_MODEL_RESOURCE_NAME}' "
484+
"is not a Generative AI model."
485+
),
486+
):
487+
batch_prediction.BatchPredictionJob.submit(
488+
source_model=_TEST_TUNED_GEMINI_MODEL_RESOURCE_NAME,
489+
input_dataset=_TEST_GCS_INPUT_URI,
490+
)
491+
492+
def test_submit_batch_prediction_job_with_invalid_model_name(self):
493+
invalid_model_name = "invalid/model/name"
494+
with pytest.raises(
495+
ValueError,
496+
match=(f"Invalid format for model name: {invalid_model_name}."),
497+
):
498+
batch_prediction.BatchPredictionJob.submit(
499+
source_model=invalid_model_name,
500+
input_dataset=_TEST_GCS_INPUT_URI,
501+
)
502+
381503
def test_submit_batch_prediction_job_with_invalid_input_dataset(self):
382504
with pytest.raises(
383505
ValueError,

vertexai/batch_prediction/_batch_prediction.py

Lines changed: 39 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from google.cloud.aiplatform import base as aiplatform_base
2323
from google.cloud.aiplatform import initializer as aiplatform_initializer
2424
from google.cloud.aiplatform import jobs
25+
from google.cloud.aiplatform import models
2526
from google.cloud.aiplatform import utils as aiplatform_utils
2627
from google.cloud.aiplatform_v1 import types as gca_types
2728
from vertexai import generative_models
@@ -32,6 +33,7 @@
3233
_LOGGER = aiplatform_base.Logger(__name__)
3334

3435
_GEMINI_MODEL_PATTERN = r"publishers/google/models/gemini"
36+
_GEMINI_TUNED_MODEL_PATTERN = r"^projects/[0-9]+?/locations/[0-9a-z-]+?/models/[0-9]+?$"
3537

3638

3739
class BatchPredictionJob(aiplatform_base._VertexAiResourceNounPlus):
@@ -64,8 +66,7 @@ def __init__(self, batch_prediction_job_name: str):
6466
self._gca_resource = self._get_gca_resource(
6567
resource_name=batch_prediction_job_name
6668
)
67-
# TODO(b/338452508) Support tuned GenAI models.
68-
if not re.search(_GEMINI_MODEL_PATTERN, self.model_name):
69+
if not self._is_genai_model(self.model_name):
6970
raise ValueError(
7071
f"BatchPredictionJob '{batch_prediction_job_name}' "
7172
f"runs with the model '{self.model_name}', "
@@ -117,9 +118,12 @@ def submit(
117118
118119
Args:
119120
source_model (Union[str, generative_models.GenerativeModel]):
120-
Model name or a GenerativeModel instance for batch prediction.
121-
Supported formats: "gemini-1.0-pro", "models/gemini-1.0-pro",
122-
and "publishers/google/models/gemini-1.0-pro"
121+
A GenAI model name or a tuned model name or a GenerativeModel instance
122+
for batch prediction.
123+
Supported formats for model name: "gemini-1.0-pro",
124+
"models/gemini-1.0-pro", and "publishers/google/models/gemini-1.0-pro"
125+
Supported formats for tuned model name: "789" and
126+
"projects/123/locations/456/models/789"
123127
input_dataset (Union[str,List[str]]):
124128
GCS URI(-s) or Bigquery URI to your input data to run batch
125129
prediction on. Example: "gs://path/to/input/data.jsonl" or
@@ -142,12 +146,13 @@ def submit(
142146
set in vertexai.init().
143147
"""
144148
# Handle model name
145-
# TODO(b/338452508) Support tuned GenAI models.
146149
model_name = cls._reconcile_model_name(
147150
source_model._model_name
148151
if isinstance(source_model, generative_models.GenerativeModel)
149152
else source_model
150153
)
154+
if not cls._is_genai_model(model_name):
155+
raise ValueError(f"Model '{model_name}' is not a Generative AI model.")
151156

152157
# Handle input URI
153158
gcs_source = None
@@ -244,9 +249,7 @@ def delete(self):
244249
def list(cls, filter=None) -> List["BatchPredictionJob"]:
245250
"""Lists all BatchPredictionJob instances that run with GenAI models."""
246251
return cls._list(
247-
cls_filter=lambda gca_resource: re.search(
248-
_GEMINI_MODEL_PATTERN, gca_resource.model
249-
),
252+
cls_filter=lambda gca_resource: cls._is_genai_model(gca_resource.model),
250253
filter=filter,
251254
)
252255

@@ -263,23 +266,44 @@ def _dashboard_uri(self) -> Optional[str]:
263266

264267
@classmethod
265268
def _reconcile_model_name(cls, model_name: str) -> str:
266-
"""Reconciles model name to a publisher model resource name."""
269+
"""Reconciles model name to a publisher model resource name or a tuned model resource name."""
267270
if not model_name:
268271
raise ValueError("model_name must not be empty")
272+
269273
if "/" not in model_name:
274+
# model name (e.g., gemini-1.0-pro)
270275
model_name = "publishers/google/models/" + model_name
271276
elif model_name.startswith("models/"):
277+
# publisher model name (e.g., models/gemini-1.0-pro)
272278
model_name = "publishers/google/" + model_name
273-
elif not model_name.startswith("publishers/google/models/") and not re.search(
274-
r"^projects/.*?/locations/.*?/publishers/google/models/.*$", model_name
279+
elif (
280+
# publisher model full name
281+
not model_name.startswith("publishers/google/models/")
282+
# tuned model full resource name
283+
and not re.search(_GEMINI_TUNED_MODEL_PATTERN, model_name)
275284
):
276285
raise ValueError(f"Invalid format for model name: {model_name}.")
277286

278-
if not re.search(_GEMINI_MODEL_PATTERN, model_name):
279-
raise ValueError(f"Model '{model_name}' is not a GenAI model.")
280-
281287
return model_name
282288

289+
@classmethod
290+
def _is_genai_model(cls, model_name: str) -> bool:
291+
"""Validates if a given model_name represents a GenAI model."""
292+
if re.search(_GEMINI_MODEL_PATTERN, model_name):
293+
# Model is a Gemini model.
294+
return True
295+
296+
if re.search(_GEMINI_TUNED_MODEL_PATTERN, model_name):
297+
model = models.Model(model_name)
298+
if (
299+
model.gca_resource.model_source_info.source_type
300+
== gca_types.model.ModelSourceInfo.ModelSourceType.GENIE
301+
):
302+
# Model is a tuned Gemini model.
303+
return True
304+
305+
return False
306+
283307
@classmethod
284308
def _complete_bq_uri(cls, uri: Optional[str] = None):
285309
"""Completes a BigQuery uri to a BigQuery table uri."""

0 commit comments

Comments
 (0)