Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

API多人调用的时候如何保持自己的模型weights不被别人覆盖? #1037

Open
KevinZhang19870314 opened this issue Apr 30, 2024 · 8 comments

Comments

@KevinZhang19870314
Copy link

KevinZhang19870314 commented Apr 30, 2024

self.vits_model = self.vits_model.half()

条件:用户A调用了set_sovits_weights设置了model a,然后还没有调用api进行推理的时候,用户B同时也调用了set_sovits_weights设置了model b。

问题:这时用户A和用户B同时调用推理api,请问这时用户A和用户B是否是使用用户B设置过的model b?这个是否可以改进?如何能够让多用户调用不同的model weights时不被别人覆盖?

@KevinZhang19870314
Copy link
Author

KevinZhang19870314 commented Apr 30, 2024

我这里想到的有2个方案,下面分别举一个示例代码:

  1. 使用类似模型池model pool的概念(这里可以限制一下pool size):
from fastapi import FastAPI
from threading import Lock

app = FastAPI()
model_pool = {}
model_pool_lock = Lock()

@app.post("/tts")
async def tts(text: str, weights_path: str):
    # 从模型池中获取模型实例
    t2s_model = get_model_instance(weights_path)
    audio = xx.tts(text, t2s_model)
    return {"audio": audio}

def get_model_instance(weights_path):
    # 使用锁来确保线程安全
    with model_pool_lock:
        if weights_path not in model_pool:
            t2s_model = init_t2s_weights(weights_path)
            model_pool[weights_path] = t2s_model
        else:
            t2s_model = model_pool[weights_path]
    return t2s_model

def init_t2s_weights(weights_path: str):
    # ...
    return t2s_model
  1. 使用functools.lru_cache缓存model结果:
from functools import lru_cache

# 假设模型池大小为10
@lru_cache(maxsize=10)
def get_model_instance(weights_path: str):
    # lru_cache缓存,因此当缓存满了后,最少使用的实例将被移除
    t2s_model = init_t2s_weights(weights_path)
    return t2s_model

def init_t2s_weights(weights_path: str):
    # ...
    return t2s_model

@app.post("/tts")
async def tts(text: str, weights_path: str):
    t2s_model = get_model_instance(weights_path)
    audio = xxx.tts(text, t2s_model )
    return {"audio": audio}

@RVC-Boss
Copy link
Owner

RVC-Boss commented May 2, 2024

@ChasonJiang

@ChasonJiang
Copy link

我这里想到的有2个方案,下面分别举一个示例代码:

这是一个关于并发的问题,可以按照你这两个方法基本思路来做。把t2s_weights换成TTS类实例就行,也就是说创建一个TTS类的实例的pool(TTS instance pool)。
不过需要注意的是,你需要以task的视角去做,也就是说,每一个tts请求是一个task,每个task将占有一个TTS类的实例,然后根据你服务器的性能决定你能并发运行多少个TTS类的实例。
还有,既然规定了能够并发的上限。也就需要为task设置一个队列(task_queue),同时需要使用一个简单的调度器,用来调度task和管理TTS_instance。
总之,实现起来还挺麻烦的。

@KevinZhang19870314
Copy link
Author

我这里想到的有2个方案,下面分别举一个示例代码:

这是一个关于并发的问题,可以按照你这两个方法基本思路来做。把t2s_weights换成TTS类实例就行,也就是说创建一个TTS类的实例的pool(TTS instance pool)。 不过需要注意的是,你需要以task的视角去做,也就是说,每一个tts请求是一个task,每个task将占有一个TTS类的实例,然后根据你服务器的性能决定你能并发运行多少个TTS类的实例。 还有,既然规定了能够并发的上限。也就需要为task设置一个队列(task_queue),同时需要使用一个简单的调度器,用来调度task和管理TTS_instance。 总之,实现起来还挺麻烦的。

看了你的回答,我觉得与这个项目的实现比较像(我之前提的一个PR),麻烦您抽时间看看是否类似?feat: add support for maximum concurrency of /api/v1/videos

@ChasonJiang
Copy link

看了你的回答,我觉得与这个项目的实现比较像(我之前提的一个PR),麻烦您抽时间看看是否类似?feat: add support for maximum concurrency of /api/v1/videos

嗯嗯,是类似的。不过要注意一个task应占有一个TTS类的实例,不然会导致生成的音频发生混乱。

@ZhangJianBeiJing
Copy link

我这里想到的有2个方案,下面分别举一个示例代码:

这是一个关于并发的问题,可以按照你这两个方法基本思路来做。把t2s_weights换成TTS类实例就行,也就是说创建一个TTS类的实例的pool(TTS instance pool)。 不过需要注意的是,你需要以task的视角去做,也就是说,每一个tts请求是一个task,每个task将占有一个TTS类的实例,然后根据你服务器的性能决定你能并发运行多少个TTS类的实例。 还有,既然规定了能够并发的上限。也就需要为task设置一个队列(task_queue),同时需要使用一个简单的调度器,用来调度task和管理TTS_instance。 总之,实现起来还挺麻烦的。

看了你的回答,我觉得与这个项目的实现比较像(我之前提的一个PR),麻烦您抽时间看看是否类似?feat: add support for maximum concurrency of /api/v1/videos

mark

@KevinZhang19870314
Copy link
Author

KevinZhang19870314 commented May 7, 2024

看了你的回答,我觉得与这个项目的实现比较像(我之前提的一个PR),麻烦您抽时间看看是否类似?feat: add support for maximum concurrency of /api/v1/videos

嗯嗯,是类似的。不过要注意一个task应占有一个TTS类的实例,不然会导致生成的音频发生混乱。

tts_pipeline = TTS(tts_config)

本地跑了一下,这一句执行时间以秒为单位的(差不多3~5秒,一般的办公笔记本,CPU),如果一个task用一个TTS实例,感觉还是比较费时间的。感觉还是多开比较好(使用类似supervisor),一个声音一个api服务,初始化的时候加载进内存。

如果执意要这么做,初始化的时候得维护一个tts的dict,后面每次调用去这个dict拿,这样应该是可以的。那么这样的话就可以用到functools.lru_cache get_tts_instance(voice_name: str)来维护这个dict了。

@KevinZhang19870314
Copy link
Author

KevinZhang19870314 commented May 7, 2024

暂时使用lru cache实现了一下,后面有空结合lru cache和feat: add support for maximum concurrency of /api/v1/videos再重构一下(如果需要的话)。

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

No branches or pull requests

4 participants