Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(SIP-95): permissions for catalogs #28317

Merged
merged 3 commits into from
May 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
35 changes: 30 additions & 5 deletions superset/commands/database/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,12 +97,37 @@ def run(self) -> Model:

db.session.commit()

# adding a new database we always want to force refresh schema list
schemas = database.get_all_schema_names(cache=False, ssh_tunnel=ssh_tunnel)
for schema in schemas:
security_manager.add_permission_view_menu(
"schema_access", security_manager.get_schema_perm(database, schema)
# add catalog/schema permissions
if database.db_engine_spec.supports_catalog:
catalogs = database.get_all_catalog_names(
cache=False,
ssh_tunnel=ssh_tunnel,
)
for catalog in catalogs:
security_manager.add_permission_view_menu(
"catalog_access",
security_manager.get_catalog_perm(
database.database_name, catalog
),
)
else:
# add a dummy catalog for DBs that don't support them
catalogs = [None]

for catalog in catalogs:
for schema in database.get_all_schema_names(
catalog=catalog,
cache=False,
ssh_tunnel=ssh_tunnel,
):
security_manager.add_permission_view_menu(
"schema_access",
security_manager.get_schema_perm(
database.database_name,
catalog,
schema,
),
)

except (
SSHTunnelInvalidError,
Expand Down
31 changes: 22 additions & 9 deletions superset/commands/database/tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

import logging
from typing import Any, cast

Expand All @@ -29,16 +32,22 @@
from superset.exceptions import SupersetException
from superset.extensions import db, security_manager
from superset.models.core import Database
from superset.utils.core import DatasourceName

logger = logging.getLogger(__name__)


class TablesDatabaseCommand(BaseCommand):
_model: Database

def __init__(self, db_id: int, schema_name: str, force: bool):
def __init__(
self,
db_id: int,
catalog_name: str | None,
schema_name: str,
force: bool,
):
self._db_id = db_id
self._catalog_name = catalog_name
self._schema_name = schema_name
self._force = force

Expand All @@ -47,11 +56,11 @@ def run(self) -> dict[str, Any]:
try:
tables = security_manager.get_datasources_accessible_by_user(
database=self._model,
catalog=self._catalog_name,
schema=self._schema_name,
datasource_names=sorted(
DatasourceName(*datasource_name)
for datasource_name in self._model.get_all_table_names_in_schema(
catalog=None,
self._model.get_all_table_names_in_schema(
catalog=self._catalog_name,
schema=self._schema_name,
force=self._force,
cache=self._model.table_cache_enabled,
Expand All @@ -62,11 +71,11 @@ def run(self) -> dict[str, Any]:

views = security_manager.get_datasources_accessible_by_user(
database=self._model,
catalog=self._catalog_name,
schema=self._schema_name,
datasource_names=sorted(
DatasourceName(*datasource_name)
for datasource_name in self._model.get_all_view_names_in_schema(
catalog=None,
self._model.get_all_view_names_in_schema(
catalog=self._catalog_name,
schema=self._schema_name,
force=self._force,
cache=self._model.table_cache_enabled,
Expand All @@ -81,11 +90,15 @@ def run(self) -> dict[str, Any]:
db.session.query(SqlaTable)
.filter(
SqlaTable.database_id == self._model.id,
SqlaTable.catalog == self._catalog_name,
SqlaTable.schema == self._schema_name,
)
.options(
load_only(
SqlaTable.schema, SqlaTable.table_name, SqlaTable.extra
SqlaTable.catalog,
SqlaTable.schema,
SqlaTable.table_name,
SqlaTable.extra,
),
lazyload(SqlaTable.columns),
lazyload(SqlaTable.metrics),
Expand Down
202 changes: 167 additions & 35 deletions superset/commands/database/update.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,9 @@
from __future__ import annotations

import logging
from typing import Any, Optional
from typing import Any

from flask_appbuilder.models.sqla import Model
from marshmallow import ValidationError

from superset import is_feature_enabled, security_manager
from superset.commands.base import BaseCommand
Expand Down Expand Up @@ -50,12 +49,12 @@


class UpdateDatabaseCommand(BaseCommand):
_model: Optional[Database]
_model: Database | None

def __init__(self, model_id: int, data: dict[str, Any]):
self._properties = data.copy()
self._model_id = model_id
self._model: Optional[Database] = None
self._model: Database | None = None

def run(self) -> Model:
self._model = DatabaseDAO.find_by_id(self._model_id)
Expand Down Expand Up @@ -85,7 +84,7 @@ def run(self) -> Model:
)
database.set_sqlalchemy_uri(database.sqlalchemy_uri)
ssh_tunnel = self._handle_ssh_tunnel(database)
self._refresh_schemas(database, original_database_name, ssh_tunnel)
self._refresh_catalogs(database, original_database_name, ssh_tunnel)
except SSHTunnelError as ex:
# allow exception to bubble for debugbing information
raise ex
Expand Down Expand Up @@ -121,67 +120,200 @@ def _handle_ssh_tunnel(self, database: Database) -> SSHTunnel | None:
ssh_tunnel_properties,
).run()

def _refresh_schemas(
def _get_catalog_names(
self,
database: Database,
original_database_name: str,
ssh_tunnel: Optional[SSHTunnel],
) -> None:
ssh_tunnel: SSHTunnel | None,
) -> set[str]:
"""
Helper method to load catalogs.

This method captures a generic exception, since errors could potentially come
from any of the 50+ database drivers we support.
"""
try:
return database.get_all_catalog_names(
force=True,
ssh_tunnel=ssh_tunnel,
)
except Exception as ex:
db.session.rollback()
raise DatabaseConnectionFailedError() from ex

def _get_schema_names(
self,
database: Database,
catalog: str | None,
ssh_tunnel: SSHTunnel | None,
) -> set[str]:
"""
Add permissions for any new schemas.
Helper method to load schemas.

This method captures a generic exception, since errors could potentially come
from any of the 50+ database drivers we support.
"""
try:
schemas = database.get_all_schema_names(ssh_tunnel=ssh_tunnel)
return database.get_all_schema_names(
force=True,
catalog=catalog,
ssh_tunnel=ssh_tunnel,
)
except Exception as ex:
db.session.rollback()
raise DatabaseConnectionFailedError() from ex

def _refresh_catalogs(
self,
database: Database,
original_database_name: str,
ssh_tunnel: SSHTunnel | None,
) -> None:
"""
Add permissions for any new catalogs and schemas.
"""
catalogs = (
self._get_catalog_names(database, ssh_tunnel)
if database.db_engine_spec.supports_catalog
else [None]
)

for catalog in catalogs:
schemas = self._get_schema_names(database, catalog, ssh_tunnel)

if catalog:
perm = security_manager.get_catalog_perm(
original_database_name,
catalog,
)
existing_pvm = security_manager.find_permission_view_menu(
"catalog_access",
perm,
)
if not existing_pvm:
# new catalog
security_manager.add_permission_view_menu(
"catalog_access",
security_manager.get_catalog_perm(
database.database_name,
catalog,
),
)
for schema in schemas:
security_manager.add_permission_view_menu(
"schema_access",
security_manager.get_schema_perm(
database.database_name,
catalog,
schema,
),
)
continue

# add possible new schemas in catalog
self._refresh_schemas(
database,
original_database_name,
catalog,
schemas,
)

if original_database_name != database.database_name:
self._rename_database_in_permissions(
database,
original_database_name,
catalog,
schemas,
)

db.session.commit()

def _refresh_schemas(
self,
database: Database,
original_database_name: str,
catalog: str | None,
schemas: set[str],
) -> None:
"""
Add new schemas that don't have permissions yet.
"""
for schema in schemas:
original_vm = security_manager.get_schema_perm(
perm = security_manager.get_schema_perm(
original_database_name,
catalog,
schema,
)
existing_pvm = security_manager.find_permission_view_menu(
"schema_access",
original_vm,
perm,
)
if not existing_pvm:
# new schema
security_manager.add_permission_view_menu(
"schema_access",
security_manager.get_schema_perm(database.database_name, schema),
new_name = security_manager.get_schema_perm(
database.database_name,
catalog,
schema,
)
continue
security_manager.add_permission_view_menu("schema_access", new_name)

if original_database_name == database.database_name:
continue
def _rename_database_in_permissions(
self,
database: Database,
original_database_name: str,
catalog: str | None,
schemas: set[str],
) -> None:
new_name = security_manager.get_catalog_perm(
database.database_name,
catalog,
)

# rename existing schema permission
existing_pvm.view_menu.name = security_manager.get_schema_perm(
# rename existing catalog permission
if catalog:
perm = security_manager.get_catalog_perm(
original_database_name,
catalog,
)
existing_pvm = security_manager.find_permission_view_menu(
"catalog_access",
perm,
)
if existing_pvm:
existing_pvm.view_menu.name = new_name

for schema in schemas:
new_name = security_manager.get_schema_perm(
database.database_name,
catalog,
schema,
)

# rename existing schema permission
perm = security_manager.get_schema_perm(
original_database_name,
catalog,
schema,
)
existing_pvm = security_manager.find_permission_view_menu(
"schema_access",
perm,
)
if existing_pvm:
existing_pvm.view_menu.name = new_name

# rename permissions on datasets and charts
for dataset in DatabaseDAO.get_datasets(
database.id,
catalog=None,
catalog=catalog,
schema=schema,
):
dataset.schema_perm = existing_pvm.view_menu.name
dataset.schema_perm = new_name
for chart in DatasetDAO.get_related_objects(dataset.id)["charts"]:
chart.schema_perm = existing_pvm.view_menu.name

db.session.commit()
chart.schema_perm = new_name

def validate(self) -> None:
exceptions: list[ValidationError] = []
database_name: Optional[str] = self._properties.get("database_name")
if database_name:
# Check database_name uniqueness
if database_name := self._properties.get("database_name"):
if not DatabaseDAO.validate_update_uniqueness(
self._model_id, database_name
self._model_id,
database_name,
):
exceptions.append(DatabaseExistsValidationError())
if exceptions:
raise DatabaseInvalidError(exceptions=exceptions)
raise DatabaseInvalidError(exceptions=[DatabaseExistsValidationError()])