|
1 | 1 | # -*- coding: utf-8 -*- |
2 | 2 |
|
3 | | -# Copyright 2023 Google LLC |
| 3 | +# Copyright 2024 Google LLC |
4 | 4 | # |
5 | 5 | # Licensed under the Apache License, Version 2.0 (the "License"); |
6 | 6 | # you may not use this file except in compliance with the License. |
|
16 | 16 | # |
17 | 17 | import grpc |
18 | 18 | import logging |
| 19 | +import ray |
| 20 | + |
19 | 21 | from typing import Dict |
20 | 22 | from typing import Optional |
21 | 23 | from google.cloud import aiplatform |
@@ -45,16 +47,30 @@ def __init__( |
45 | 47 | persistent_resource_id, |
46 | 48 | " failed to start Head node properly.", |
47 | 49 | ) |
48 | | - |
49 | | - super().__init__( |
50 | | - dashboard_url=dashboard_uri, |
51 | | - python_version=ray_client_context.python_version, |
52 | | - ray_version=ray_client_context.ray_version, |
53 | | - ray_commit=ray_client_context.ray_commit, |
54 | | - protocol_version=ray_client_context.protocol_version, |
55 | | - _num_clients=ray_client_context._num_clients, |
56 | | - _context_to_restore=ray_client_context._context_to_restore, |
57 | | - ) |
| 50 | + if ray.__version__ == "2.33.0": |
| 51 | + super().__init__( |
| 52 | + dashboard_url=dashboard_uri, |
| 53 | + python_version=ray_client_context.python_version, |
| 54 | + ray_version=ray_client_context.ray_version, |
| 55 | + ray_commit=ray_client_context.ray_commit, |
| 56 | + _num_clients=ray_client_context._num_clients, |
| 57 | + _context_to_restore=ray_client_context._context_to_restore, |
| 58 | + ) |
| 59 | + elif ray.__version__ == "2.9.3": |
| 60 | + super().__init__( |
| 61 | + dashboard_url=dashboard_uri, |
| 62 | + python_version=ray_client_context.python_version, |
| 63 | + ray_version=ray_client_context.ray_version, |
| 64 | + ray_commit=ray_client_context.ray_commit, |
| 65 | + protocol_version=ray_client_context.protocol_version, |
| 66 | + _num_clients=ray_client_context._num_clients, |
| 67 | + _context_to_restore=ray_client_context._context_to_restore, |
| 68 | + ) |
| 69 | + else: |
| 70 | + raise ImportError( |
| 71 | + f"[Ray on Vertex AI]: Unsupported version {ray.__version__}." |
| 72 | + + "Only 2.33.0 and 2.9.3 are supported." |
| 73 | + ) |
58 | 74 | self.persistent_resource_id = persistent_resource_id |
59 | 75 | self.vertex_sdk_version = str(VERTEX_SDK_VERSION) |
60 | 76 | self.shell_uri = ray_head_uris.get("RAY_HEAD_NODE_INTERACTIVE_SHELL_URI") |
|
0 commit comments