Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ 适配 nb orm #399

Draft
wants to merge 7 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions nonebot_bison/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from nonebot.plugin import PluginMetadata, require

require("nonebot_plugin_apscheduler")
require("nonebot_plugin_datastore")
require("nonebot_plugin_orm")
require("nonebot_plugin_saa")

import nonebot_plugin_saa

from .config import migrations
from .plugin_config import PlugConfig, plugin_config
from . import post, send, types, utils, config, platform, bootstrap, scheduler, admin_page, sub_manager

Expand All @@ -29,7 +30,11 @@
homepage="https://github.com/felinae98/nonebot-bison",
config=PlugConfig,
supported_adapters=__supported_adapters__,
extra={"version": __help__version__, "docs": "https://nonebot-bison.netlify.app/"},
extra={
"version": __help__version__,
"docs": "https://nonebot-bison.netlify.app/",
"orm_version_location": migrations,
},
)

__all__ = [
Expand Down
36 changes: 3 additions & 33 deletions nonebot_bison/bootstrap.py
Original file line number Diff line number Diff line change
@@ -1,44 +1,14 @@
from nonebot import get_driver
from nonebot.log import logger
from sqlalchemy import text, inspect
from nonebot_plugin_datastore.db import get_engine, pre_db_init, post_db_init

from .config.db_migration import data_migrate
from .scheduler.manager import init_scheduler
from .config.config_legacy import start_up as legacy_db_startup

driver = get_driver()

@pre_db_init
async def pre():
def _has_table(conn, table_name):
insp = inspect(conn)
return insp.has_table(table_name)

async with get_engine().begin() as conn:
if not await conn.run_sync(_has_table, "alembic_version"):
logger.debug("未发现默认版本数据库,开始初始化")
return

logger.debug("发现默认版本数据库,开始检查版本")
t = await conn.scalar(text("select version_num from alembic_version"))
if t not in [
"4a46ba54a3f3", # alter_type
"5f3370328e44", # add_time_weight_table
"0571870f5222", # init_db
"a333d6224193", # add_last_scheduled_time
"c97c445e2bdb", # add_constraint
]:
logger.warning(f"当前数据库版本:{t},不是插件的版本,已跳过。")
return

logger.debug(f"当前数据库版本:{t},是插件的版本,开始迁移。")
# 删除可能存在的版本数据库
if await conn.run_sync(_has_table, "nonebot_bison_alembic_version"):
await conn.execute(text("drop table nonebot_bison_alembic_version"))

await conn.execute(text("alter table alembic_version rename to nonebot_bison_alembic_version"))


@post_db_init
@driver.on_startup
async def post():
# legacy db
legacy_db_startup()
Expand Down
1 change: 1 addition & 0 deletions nonebot_bison/config/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from . import migrations as migrations
from .db_config import config as config
from .utils import NoSuchUserException as NoSuchUserException
from .utils import NoSuchTargetException as NoSuchTargetException
Expand Down
24 changes: 12 additions & 12 deletions nonebot_bison/config/db_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@

from sqlalchemy.orm import selectinload
from sqlalchemy.exc import IntegrityError
from nonebot_plugin_orm import get_session
from sqlalchemy import func, delete, select
from nonebot_plugin_saa import PlatformTarget
from nonebot_plugin_datastore import create_session

from ..types import Tag
from ..types import Target as T_Target
Expand Down Expand Up @@ -46,7 +46,7 @@ async def add_subscribe(
cats: list[Category],
tags: list[Tag],
):
async with create_session() as session:
async with get_session() as session:
db_user_stmt = select(User).where(User.user_target == user.dict())
db_user: User | None = await session.scalar(db_user_stmt)
if not db_user:
Expand Down Expand Up @@ -74,7 +74,7 @@ async def add_subscribe(
raise e

async def list_subscribe(self, user: PlatformTarget) -> Sequence[Subscribe]:
async with create_session() as session:
async with get_session() as session:
query_stmt = (
select(Subscribe)
.where(User.user_target == user.dict())
Expand All @@ -86,7 +86,7 @@ async def list_subscribe(self, user: PlatformTarget) -> Sequence[Subscribe]:

async def list_subs_with_all_info(self) -> Sequence[Subscribe]:
"""获取数据库中带有user、target信息的subscribe数据"""
async with create_session() as session:
async with get_session() as session:
query_stmt = (
select(Subscribe).join(User).options(selectinload(Subscribe.target), selectinload(Subscribe.user))
)
Expand All @@ -95,7 +95,7 @@ async def list_subs_with_all_info(self) -> Sequence[Subscribe]:
return subs

async def del_subscribe(self, user: PlatformTarget, target: str, platform_name: str):
async with create_session() as session:
async with get_session() as session:
user_obj = await session.scalar(select(User).where(User.user_target == user.dict()))
target_obj = await session.scalar(
select(Target).where(Target.platform_name == platform_name, Target.target == target)
Expand All @@ -118,7 +118,7 @@ async def update_subscribe(
cats: list,
tags: list,
):
async with create_session() as sess:
async with get_session() as sess:
subscribe_obj: Subscribe = await sess.scalar(
select(Subscribe)
.where(
Expand All @@ -136,13 +136,13 @@ async def update_subscribe(
await sess.commit()

async def get_platform_target(self, platform_name: str) -> Sequence[Target]:
async with create_session() as sess:
async with get_session() as sess:
subq = select(Subscribe.target_id).distinct().subquery()
query = select(Target).join(subq).where(Target.platform_name == platform_name)
return (await sess.scalars(query)).all()

async def get_time_weight_config(self, target: T_Target, platform_name: str) -> WeightConfig:
async with create_session() as sess:
async with get_session() as sess:
time_weight_conf = (
await sess.scalars(
select(ScheduleTimeWeight)
Expand All @@ -167,7 +167,7 @@ async def get_time_weight_config(self, target: T_Target, platform_name: str) ->
)

async def update_time_weight_config(self, target: T_Target, platform_name: str, conf: WeightConfig):
async with create_session() as sess:
async with get_session() as sess:
targetObj = await sess.scalar(
select(Target).where(Target.platform_name == platform_name, Target.target == target)
)
Expand All @@ -191,7 +191,7 @@ async def update_time_weight_config(self, target: T_Target, platform_name: str,
async def get_current_weight_val(self, platform_list: list[str]) -> dict[str, int]:
res = {}
cur_time = _get_time()
async with create_session() as sess:
async with get_session() as sess:
targets = (
await sess.scalars(
select(Target)
Expand All @@ -210,7 +210,7 @@ async def get_current_weight_val(self, platform_list: list[str]) -> dict[str, in
return res

async def get_platform_target_subscribers(self, platform_name: str, target: T_Target) -> list[UserSubInfo]:
async with create_session() as sess:
async with get_session() as sess:
query = (
select(Subscribe)
.join(Target)
Expand All @@ -231,7 +231,7 @@ async def get_all_weight_config(
self,
) -> dict[str, dict[str, PlatformWeightConfigResp]]:
res: dict[str, dict[str, PlatformWeightConfigResp]] = defaultdict(dict)
async with create_session() as sess:
async with get_session() as sess:
query = select(Target)
targets = (await sess.scalars(query)).all()
query = select(ScheduleTimeWeight).options(selectinload(ScheduleTimeWeight.target))
Expand Down
5 changes: 2 additions & 3 deletions nonebot_bison/config/db_migration.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from nonebot.log import logger
from nonebot_plugin_datastore.db import get_engine
from sqlalchemy.ext.asyncio.session import AsyncSession
from nonebot_plugin_orm import get_session
from nonebot_plugin_saa import TargetQQGroup, TargetQQPrivate

from .db_model import User, Target, Subscribe
Expand All @@ -12,7 +11,7 @@ async def data_migrate():
if config.available:
logger.warning("You are still using legacy db, migrating to sqlite")
all_subs: list[ConfigContent] = [ConfigContent(**item) for item in config.get_all_subscribe().all()]
async with AsyncSession(get_engine()) as sess:
async with get_session() as sess:
user_to_create = []
subscribe_to_create = []
platform_target_map: dict[str, tuple[Target, str, int]] = {}
Expand Down
6 changes: 1 addition & 5 deletions nonebot_bison/config/db_model.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,13 @@
import datetime
from pathlib import Path

from nonebot_plugin_orm import Model
from sqlalchemy.dialects.postgresql import JSONB
from nonebot_plugin_saa.utils import PlatformTarget
from nonebot_plugin_datastore import get_plugin_data
from sqlalchemy.orm import Mapped, relationship, mapped_column
from sqlalchemy import JSON, String, ForeignKey, UniqueConstraint

from ..types import Tag, Category

Model = get_plugin_data().Model
get_plugin_data().set_migration_dir(Path(__file__).parent / "migrations")


class User(Model):
id: Mapped[int] = mapped_column(primary_key=True)
Expand Down
60 changes: 0 additions & 60 deletions nonebot_bison/config/migrations/0571870f5222_init_db.py

This file was deleted.

85 changes: 85 additions & 0 deletions nonebot_bison/config/migrations/0f2a5973c8ae_init_db.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
"""init db

修订 ID: 0f2a5973c8ae
父修订:
创建时间: 2023-10-12 19:23:28.933609

"""
from __future__ import annotations

from collections.abc import Sequence

import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql

revision: str = "0f2a5973c8ae"
down_revision: str | Sequence[str] | None = None
branch_labels: str | Sequence[str] | None = ("nonebot_bison",)
depends_on: str | Sequence[str] | None = None


def upgrade(name: str = "") -> None:
if name:
return
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"nonebot_bison_target",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("platform_name", sa.String(length=20), nullable=False),
sa.Column("target", sa.String(length=1024), nullable=False),
sa.Column("target_name", sa.String(length=1024), nullable=False),
sa.Column("default_schedule_weight", sa.Integer(), nullable=False),
sa.PrimaryKeyConstraint("id", name=op.f("pk_nonebot_bison_target")),
sa.UniqueConstraint("target", "platform_name", name="unique-target-constraint"),
)
op.create_table(
"nonebot_bison_user",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("user_target", sa.JSON().with_variant(postgresql.JSONB(), "postgresql"), nullable=False),
sa.PrimaryKeyConstraint("id", name=op.f("pk_nonebot_bison_user")),
)
op.create_table(
"nonebot_bison_scheduletimeweight",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("target_id", sa.Integer(), nullable=False),
sa.Column("start_time", sa.Time(), nullable=False),
sa.Column("end_time", sa.Time(), nullable=False),
sa.Column("weight", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(
["target_id"],
["nonebot_bison_target.id"],
name=op.f("fk_nonebot_bison_scheduletimeweight_target_id_nonebot_bison_target"),
),
sa.PrimaryKeyConstraint("id", name=op.f("pk_nonebot_bison_scheduletimeweight")),
)
op.create_table(
"nonebot_bison_subscribe",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("target_id", sa.Integer(), nullable=False),
sa.Column("user_id", sa.Integer(), nullable=False),
sa.Column("categories", sa.JSON(), nullable=False),
sa.Column("tags", sa.JSON(), nullable=False),
sa.ForeignKeyConstraint(
["target_id"],
["nonebot_bison_target.id"],
name=op.f("fk_nonebot_bison_subscribe_target_id_nonebot_bison_target"),
),
sa.ForeignKeyConstraint(
["user_id"], ["nonebot_bison_user.id"], name=op.f("fk_nonebot_bison_subscribe_user_id_nonebot_bison_user")
),
sa.PrimaryKeyConstraint("id", name=op.f("pk_nonebot_bison_subscribe")),
sa.UniqueConstraint("target_id", "user_id", name="unique-subscribe-constraint"),
)
# ### end Alembic commands ###


def downgrade(name: str = "") -> None:
if name:
return
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table("nonebot_bison_subscribe")
op.drop_table("nonebot_bison_scheduletimeweight")
op.drop_table("nonebot_bison_user")
op.drop_table("nonebot_bison_target")
# ### end Alembic commands ###
Loading
Loading