1515# limitations under the License.
1616#
1717
18+ import importlib
1819import os
1920import pickle
2021import tempfile
4546 "save_method" : "_save_sklearn_model" ,
4647 "load_method" : "_load_sklearn_model" ,
4748 "model_file" : "model.pkl" ,
48- }
49+ },
50+ "xgboost" : {
51+ "save_method" : "_save_xgboost_model" ,
52+ "load_method" : "_load_xgboost_model" ,
53+ "model_file" : "model.bst" ,
54+ },
4955}
5056
5157
5258def save_model (
53- model : "sklearn.base.BaseEstimator" , # noqa: F821
59+ model : Union [ "sklearn.base.BaseEstimator" , "xgb.Booster" ] , # noqa: F821
5460 artifact_id : Optional [str ] = None ,
5561 * ,
5662 uri : Optional [str ] = None ,
@@ -63,7 +69,7 @@ def save_model(
6369) -> google_artifact_schema .ExperimentModel :
6470 """Saves a ML model into a MLMD artifact.
6571
66- Supported model frameworks: sklearn.
72+ Supported model frameworks: sklearn, xgboost .
6773
6874 Example usage:
6975 aiplatform.init(project="my-project", location="my-location", staging_bucket="gs://my-bucket")
@@ -72,7 +78,7 @@ def save_model(
7278 aiplatform.save_model(model, "my-sklearn-model")
7379
7480 Args:
75- model (sklearn.base.BaseEstimator):
81+ model (Union[" sklearn.base.BaseEstimator", "xgb.Booster"] ):
7682 Required. A machine learning model.
7783 artifact_id (str):
7884 Optional. The resource id of the artifact. This id must be globally unique
@@ -116,10 +122,23 @@ def save_model(
116122 except ImportError :
117123 pass
118124 else :
119- if isinstance (model , sklearn .base .BaseEstimator ):
125+ # An instance of sklearn.base.BaseEstimator might be a sklearn model
126+ # or a xgboost/lightgbm model implemented on top of sklearn.
127+ if isinstance (
128+ model , sklearn .base .BaseEstimator
129+ ) and model .__class__ .__module__ .startswith ("sklearn" ):
120130 framework_name = "sklearn"
121131 framework_version = sklearn .__version__
122132
133+ try :
134+ import xgboost as xgb
135+ except ImportError :
136+ pass
137+ else :
138+ if isinstance (model , (xgb .Booster , xgb .XGBModel )):
139+ framework_name = "xgboost"
140+ framework_version = xgb .__version__
141+
123142 if framework_name not in _FRAMEWORK_SPECS :
124143 raise ValueError (
125144 f"Model type { model .__class__ .__module__ } .{ model .__class__ .__name__ } not supported."
@@ -305,9 +324,24 @@ def _save_sklearn_model(
305324 pickle .dump (model , f , protocol = _PICKLE_PROTOCOL )
306325
307326
327+ def _save_xgboost_model (
328+ model : Union ["xgb.Booster" , "xgb.XGBModel" ], # noqa: F821
329+ path : str ,
330+ ):
331+ """Saves a xgboost model.
332+
333+ Args:
334+ model (Union[xgb.Booster, xgb.XGBModel]):
335+ Requred. A xgboost model.
336+ path (str):
337+ Required. The local path to save the model.
338+ """
339+ model .save_model (path )
340+
341+
308342def load_model (
309343 model : Union [str , google_artifact_schema .ExperimentModel ]
310- ) -> "sklearn.base.BaseEstimator" : # noqa: F821
344+ ) -> Union [ "sklearn.base.BaseEstimator" , "xgb.Booster" ] : # noqa: F821
311345 """Retrieves the original ML model from an ExperimentModel resource.
312346
313347 Args:
@@ -375,7 +409,44 @@ def _load_sklearn_model(
375409 return sk_model
376410
377411
378- # TODO(b/264893283)
412+ def _load_xgboost_model (
413+ model_file : str ,
414+ model_artifact : google_artifact_schema .ExperimentModel ,
415+ ) -> Union ["xgb.Booster" , "xgb.XGBModel" ]: # noqa: F821
416+ """Loads a xgboost model from local path.
417+
418+ Args:
419+ model_file (str):
420+ Required. A local model file to load.
421+ model_artifact (google_artifact_schema.ExperimentModel):
422+ Required. The artifact that saved the model.
423+ Returns:
424+ The xgboost model instance.
425+
426+ Raises:
427+ ImportError: if xgboost is not installed.
428+ """
429+ try :
430+ import xgboost as xgb
431+ except ImportError :
432+ raise ImportError (
433+ "xgboost is not installed and is required for loading models."
434+ ) from None
435+
436+ if xgb .__version__ < model_artifact .framework_version :
437+ _LOGGER .warning (
438+ f"The original model was saved via xgboost { model_artifact .framework_version } . "
439+ f"You are using xgboost { xgb .__version__ } ."
440+ "Attempting to load model..."
441+ )
442+
443+ module , class_name = model_artifact .model_class .rsplit ("." , maxsplit = 1 )
444+ xgb_model = getattr (importlib .import_module (module ), class_name )()
445+ xgb_model .load_model (model_file )
446+
447+ return xgb_model
448+
449+
379450def register_model (
380451 model : Union [str , google_artifact_schema .ExperimentModel ],
381452 * ,
0 commit comments