1515"""Classes for working with language models."""
1616
1717import dataclasses
18- from typing import Any , AsyncIterator , Dict , Iterator , List , Optional , Sequence , Union
18+ from typing import Any , AsyncIterator , Dict , Iterator , List , Literal , Optional , Sequence , Union
1919import warnings
2020
2121from google .cloud import aiplatform
4242# Endpoint label/metadata key to preserve the base model ID information
4343_TUNING_BASE_MODEL_ID_LABEL_KEY = "google-vertex-llm-tuning-base-model-id"
4444
45+ _ACCELERATOR_TYPES = ["TPU" , "GPU" ]
46+ _ACCELERATOR_TYPE_TYPE = Literal ["TPU" , "GPU" ]
47+
4548
4649def _get_model_id_from_tuning_model_id (tuning_model_id : str ) -> str :
4750 """Gets the base model ID for the model ID labels used the tuned models.
@@ -166,6 +169,7 @@ def tune_model(
166169 model_display_name : Optional [str ] = None ,
167170 tuning_evaluation_spec : Optional ["TuningEvaluationSpec" ] = None ,
168171 default_context : Optional [str ] = None ,
172+ accelerator_type : Optional [_ACCELERATOR_TYPE_TYPE ] = None ,
169173 ) -> "_LanguageModelTuningJob" :
170174 """Tunes a model based on training data.
171175
@@ -191,6 +195,7 @@ def tune_model(
191195 model_display_name: Custom display name for the tuned model.
192196 tuning_evaluation_spec: Specification for the model evaluation during tuning.
193197 default_context: The context to use for all training samples by default.
198+ accelerator_type: Type of accelerator to use. Can be "TPU" or "GPU".
194199
195200 Returns:
196201 A `LanguageModelTuningJob` object that represents the tuning job.
@@ -252,6 +257,14 @@ def tune_model(
252257 if default_context :
253258 tuning_parameters ["default_context" ] = default_context
254259
260+ if accelerator_type :
261+ if accelerator_type not in _ACCELERATOR_TYPES :
262+ raise ValueError (
263+ f"Unsupported accelerator type: { accelerator_type } ."
264+ f" Supported types: { _ACCELERATOR_TYPES } "
265+ )
266+ tuning_parameters ["accelerator_type" ] = accelerator_type
267+
255268 return self ._tune_model (
256269 training_data = training_data ,
257270 tuning_parameters = tuning_parameters ,
@@ -336,6 +349,7 @@ def tune_model(
336349 tuned_model_location : Optional [str ] = None ,
337350 model_display_name : Optional [str ] = None ,
338351 tuning_evaluation_spec : Optional ["TuningEvaluationSpec" ] = None ,
352+ accelerator_type : Optional [_ACCELERATOR_TYPE_TYPE ] = None ,
339353 ) -> "_LanguageModelTuningJob" :
340354 """Tunes a model based on training data.
341355
@@ -357,6 +371,7 @@ def tune_model(
357371 tuned_model_location: GCP location where the tuned model should be deployed. Only "us-central1" is supported for now.
358372 model_display_name: Custom display name for the tuned model.
359373 tuning_evaluation_spec: Specification for the model evaluation during tuning.
374+ accelerator_type: Type of accelerator to use. Can be "TPU" or "GPU".
360375
361376 Returns:
362377 A `LanguageModelTuningJob` object that represents the tuning job.
@@ -376,6 +391,7 @@ def tune_model(
376391 tuned_model_location = tuned_model_location ,
377392 model_display_name = model_display_name ,
378393 tuning_evaluation_spec = tuning_evaluation_spec ,
394+ accelerator_type = accelerator_type ,
379395 )
380396
381397
@@ -393,6 +409,7 @@ def tune_model(
393409 tuned_model_location : Optional [str ] = None ,
394410 model_display_name : Optional [str ] = None ,
395411 tuning_evaluation_spec : Optional ["TuningEvaluationSpec" ] = None ,
412+ accelerator_type : Optional [_ACCELERATOR_TYPE_TYPE ] = None ,
396413 ) -> "_LanguageModelTuningJob" :
397414 """Tunes a model based on training data.
398415
@@ -421,6 +438,7 @@ def tune_model(
421438 tuned_model_location: GCP location where the tuned model should be deployed. Only "us-central1" is supported for now.
422439 model_display_name: Custom display name for the tuned model.
423440 tuning_evaluation_spec: Specification for the model evaluation during tuning.
441+ accelerator_type: Type of accelerator to use. Can be "TPU" or "GPU".
424442
425443 Returns:
426444 A `LanguageModelTuningJob` object that represents the tuning job.
@@ -441,6 +459,7 @@ def tune_model(
441459 tuned_model_location = tuned_model_location ,
442460 model_display_name = model_display_name ,
443461 tuning_evaluation_spec = tuning_evaluation_spec ,
462+ accelerator_type = accelerator_type ,
444463 )
445464 tuned_model = job .get_tuned_model ()
446465 self ._endpoint = tuned_model ._endpoint
@@ -461,6 +480,7 @@ def tune_model(
461480 tuned_model_location : Optional [str ] = None ,
462481 model_display_name : Optional [str ] = None ,
463482 default_context : Optional [str ] = None ,
483+ accelerator_type : Optional [_ACCELERATOR_TYPE_TYPE ] = None ,
464484 ) -> "_LanguageModelTuningJob" :
465485 """Tunes a model based on training data.
466486
@@ -485,6 +505,7 @@ def tune_model(
485505 tuned_model_location: GCP location where the tuned model should be deployed. Only "us-central1" is supported for now.
486506 model_display_name: Custom display name for the tuned model.
487507 default_context: The context to use for all training samples by default.
508+ accelerator_type: Type of accelerator to use. Can be "TPU" or "GPU".
488509
489510 Returns:
490511 A `LanguageModelTuningJob` object that represents the tuning job.
@@ -504,6 +525,7 @@ def tune_model(
504525 tuned_model_location = tuned_model_location ,
505526 model_display_name = model_display_name ,
506527 default_context = default_context ,
528+ accelerator_type = accelerator_type ,
507529 )
508530
509531
@@ -521,6 +543,7 @@ def tune_model(
521543 tuned_model_location : Optional [str ] = None ,
522544 model_display_name : Optional [str ] = None ,
523545 default_context : Optional [str ] = None ,
546+ accelerator_type : Optional [_ACCELERATOR_TYPE_TYPE ] = None ,
524547 ) -> "_LanguageModelTuningJob" :
525548 """Tunes a model based on training data.
526549
@@ -549,6 +572,7 @@ def tune_model(
549572 tuned_model_location: GCP location where the tuned model should be deployed. Only "us-central1" is supported for now.
550573 model_display_name: Custom display name for the tuned model.
551574 default_context: The context to use for all training samples by default.
575+ accelerator_type: Type of accelerator to use. Can be "TPU" or "GPU".
552576
553577 Returns:
554578 A `LanguageModelTuningJob` object that represents the tuning job.
@@ -569,6 +593,7 @@ def tune_model(
569593 tuned_model_location = tuned_model_location ,
570594 model_display_name = model_display_name ,
571595 default_context = default_context ,
596+ accelerator_type = accelerator_type ,
572597 )
573598 tuned_model = job .get_tuned_model ()
574599 self ._endpoint = tuned_model ._endpoint
0 commit comments