2525import vertexai
2626from google .cloud .aiplatform import base as aiplatform_base
2727from google .cloud .aiplatform import initializer as aiplatform_initializer
28- from google .cloud .aiplatform .compat .services import job_service_client
28+ from google .cloud .aiplatform .compat .services import (
29+ job_service_client ,
30+ model_service_client ,
31+ )
2932from google .cloud .aiplatform .compat .types import (
3033 batch_prediction_job as gca_batch_prediction_job_compat ,
3134 io as gca_io_compat ,
3235 job_state as gca_job_state_compat ,
36+ model as gca_model ,
3337)
3438from vertexai .preview import batch_prediction
3539from vertexai .generative_models import GenerativeModel
4347
4448_TEST_GEMINI_MODEL_NAME = "gemini-1.0-pro"
4549_TEST_GEMINI_MODEL_RESOURCE_NAME = f"publishers/google/models/{ _TEST_GEMINI_MODEL_NAME } "
50+ _TEST_TUNED_GEMINI_MODEL_RESOURCE_NAME = "projects/123/locations/us-central1/models/456"
4651_TEST_PALM_MODEL_NAME = "text-bison"
4752_TEST_PALM_MODEL_RESOURCE_NAME = f"publishers/google/models/{ _TEST_PALM_MODEL_NAME } "
4853
@@ -122,6 +127,48 @@ def get_batch_prediction_job_with_gcs_output_mock():
122127 yield get_job_mock
123128
124129
130+ @pytest .fixture
131+ def get_batch_prediction_job_with_tuned_gemini_model_mock ():
132+ with mock .patch .object (
133+ job_service_client .JobServiceClient , "get_batch_prediction_job"
134+ ) as get_job_mock :
135+ get_job_mock .return_value = gca_batch_prediction_job_compat .BatchPredictionJob (
136+ name = _TEST_BATCH_PREDICTION_JOB_NAME ,
137+ display_name = _TEST_DISPLAY_NAME ,
138+ model = _TEST_TUNED_GEMINI_MODEL_RESOURCE_NAME ,
139+ state = _TEST_JOB_STATE_SUCCESS ,
140+ output_info = gca_batch_prediction_job_compat .BatchPredictionJob .OutputInfo (
141+ gcs_output_directory = _TEST_GCS_OUTPUT_PREFIX
142+ ),
143+ )
144+ yield get_job_mock
145+
146+
147+ @pytest .fixture
148+ def get_gemini_model_mock ():
149+ with mock .patch .object (
150+ model_service_client .ModelServiceClient , "get_model"
151+ ) as get_model_mock :
152+ get_model_mock .return_value = gca_model .Model (
153+ name = _TEST_TUNED_GEMINI_MODEL_RESOURCE_NAME ,
154+ model_source_info = gca_model .ModelSourceInfo (
155+ source_type = gca_model .ModelSourceInfo .ModelSourceType .GENIE
156+ ),
157+ )
158+ yield get_model_mock
159+
160+
161+ @pytest .fixture
162+ def get_non_gemini_model_mock ():
163+ with mock .patch .object (
164+ model_service_client .ModelServiceClient , "get_model"
165+ ) as get_model_mock :
166+ get_model_mock .return_value = gca_model .Model (
167+ name = _TEST_TUNED_GEMINI_MODEL_RESOURCE_NAME ,
168+ )
169+ yield get_model_mock
170+
171+
125172@pytest .fixture
126173def get_batch_prediction_job_invalid_model_mock ():
127174 with mock .patch .object (
@@ -205,6 +252,21 @@ def test_init_batch_prediction_job(
205252 name = _TEST_BATCH_PREDICTION_JOB_NAME , retry = aiplatform_base ._DEFAULT_RETRY
206253 )
207254
255+ def test_init_batch_prediction_job_with_tuned_gemini_model (
256+ self ,
257+ get_batch_prediction_job_with_tuned_gemini_model_mock ,
258+ get_gemini_model_mock ,
259+ ):
260+ batch_prediction .BatchPredictionJob (_TEST_BATCH_PREDICTION_JOB_ID )
261+
262+ get_batch_prediction_job_with_tuned_gemini_model_mock .assert_called_once_with (
263+ name = _TEST_BATCH_PREDICTION_JOB_NAME , retry = aiplatform_base ._DEFAULT_RETRY
264+ )
265+ get_gemini_model_mock .assert_called_once_with (
266+ name = _TEST_TUNED_GEMINI_MODEL_RESOURCE_NAME ,
267+ retry = aiplatform_base ._DEFAULT_RETRY ,
268+ )
269+
208270 @pytest .mark .usefixtures ("get_batch_prediction_job_invalid_model_mock" )
209271 def test_init_batch_prediction_job_invalid_model (self ):
210272 with pytest .raises (
@@ -217,6 +279,23 @@ def test_init_batch_prediction_job_invalid_model(self):
217279 ):
218280 batch_prediction .BatchPredictionJob (_TEST_BATCH_PREDICTION_JOB_ID )
219281
282+ @pytest .mark .usefixtures (
283+ "get_batch_prediction_job_with_tuned_gemini_model_mock" ,
284+ "get_non_gemini_model_mock" ,
285+ )
286+ def test_init_batch_prediction_job_with_invalid_tuned_model (
287+ self ,
288+ ):
289+ with pytest .raises (
290+ ValueError ,
291+ match = (
292+ f"BatchPredictionJob '{ _TEST_BATCH_PREDICTION_JOB_ID } ' "
293+ f"runs with the model '{ _TEST_TUNED_GEMINI_MODEL_RESOURCE_NAME } ', "
294+ "which is not a GenAI model."
295+ ),
296+ ):
297+ batch_prediction .BatchPredictionJob (_TEST_BATCH_PREDICTION_JOB_ID )
298+
220299 @pytest .mark .usefixtures ("get_batch_prediction_job_with_gcs_output_mock" )
221300 def test_submit_batch_prediction_job_with_gcs_input (
222301 self , create_batch_prediction_job_mock
@@ -368,16 +447,59 @@ def test_submit_batch_prediction_job_with_bq_input_without_output_uri_prefix(
368447 timeout = None ,
369448 )
370449
450+ @pytest .mark .usefixtures ("create_batch_prediction_job_mock" )
451+ def test_submit_batch_prediction_job_with_tuned_model (
452+ self ,
453+ get_gemini_model_mock ,
454+ ):
455+ job = batch_prediction .BatchPredictionJob .submit (
456+ source_model = _TEST_TUNED_GEMINI_MODEL_RESOURCE_NAME ,
457+ input_dataset = _TEST_BQ_INPUT_URI ,
458+ )
459+
460+ assert job .gca_resource == _TEST_GAPIC_BATCH_PREDICTION_JOB
461+ get_gemini_model_mock .assert_called_once_with (
462+ name = _TEST_TUNED_GEMINI_MODEL_RESOURCE_NAME ,
463+ retry = aiplatform_base ._DEFAULT_RETRY ,
464+ )
465+
371466 def test_submit_batch_prediction_job_with_invalid_source_model (self ):
372467 with pytest .raises (
373468 ValueError ,
374- match = (f"Model '{ _TEST_PALM_MODEL_RESOURCE_NAME } ' is not a GenAI model." ),
469+ match = (
470+ f"Model '{ _TEST_PALM_MODEL_RESOURCE_NAME } ' is not a Generative AI model."
471+ ),
375472 ):
376473 batch_prediction .BatchPredictionJob .submit (
377474 source_model = _TEST_PALM_MODEL_NAME ,
378475 input_dataset = _TEST_GCS_INPUT_URI ,
379476 )
380477
478+ @pytest .mark .usefixtures ("get_non_gemini_model_mock" )
479+ def test_submit_batch_prediction_job_with_non_gemini_tuned_model (self ):
480+ with pytest .raises (
481+ ValueError ,
482+ match = (
483+ f"Model '{ _TEST_TUNED_GEMINI_MODEL_RESOURCE_NAME } ' "
484+ "is not a Generative AI model."
485+ ),
486+ ):
487+ batch_prediction .BatchPredictionJob .submit (
488+ source_model = _TEST_TUNED_GEMINI_MODEL_RESOURCE_NAME ,
489+ input_dataset = _TEST_GCS_INPUT_URI ,
490+ )
491+
492+ def test_submit_batch_prediction_job_with_invalid_model_name (self ):
493+ invalid_model_name = "invalid/model/name"
494+ with pytest .raises (
495+ ValueError ,
496+ match = (f"Invalid format for model name: { invalid_model_name } ." ),
497+ ):
498+ batch_prediction .BatchPredictionJob .submit (
499+ source_model = invalid_model_name ,
500+ input_dataset = _TEST_GCS_INPUT_URI ,
501+ )
502+
381503 def test_submit_batch_prediction_job_with_invalid_input_dataset (self ):
382504 with pytest .raises (
383505 ValueError ,
0 commit comments