Compare commits
6 Commits
openapi-sp ... venv-manag
| Author | SHA1 | Date |
|---|---|---|
|  | c3f48337ae |  |
|  | ded60c33a0 |  |
|  | 8bb858e4d3 |  |
|  | 57893c843f |  |
|  | 65da29aaa9 |  |
|  | 10024a38ea |  |
125 app/venv_management.py Normal file
@@ -0,0 +1,125 @@
import torch
import torchvision
import torchaudio
from dataclasses import dataclass

import importlib
if importlib.util.find_spec("torch_directml"):
    from pip._vendor import pkg_resources


class VEnvException(Exception):
    pass


@dataclass
class TorchVersionInfo:
    name: str = None
    version: str = None
    extension: str = None
    is_nightly: bool = False
    is_cpu: bool = False
    is_cuda: bool = False
    is_xpu: bool = False
    is_rocm: bool = False
    is_directml: bool = False


def get_bootstrap_requirements_string():
    '''
    Get string to insert into a 'pip install' command to get the same torch dependencies as current venv.
    '''
    torch_info = get_torch_info(torch)
    packages = [torchvision, torchaudio]
    infos = [torch_info] + [get_torch_info(x) for x in packages]
    # directml should be first dependency, if exists
    directml_info = get_torch_directml_info()
    if directml_info is not None:
        infos = [directml_info] + infos
    # create list of strings to combine into install string
    install_str_list = []
    for info in infos:
        info_string = f"{info.name}=={info.version}"
        if not info.is_cpu and not info.is_directml:
            info_string = f"{info_string}+{info.extension}"
        install_str_list.append(info_string)
    # handle extra_index_url, if needed
    extra_index_url = get_index_url(torch_info)
    if extra_index_url:
        install_str_list.append(extra_index_url)
    # format nightly install properly
    if torch_info.is_nightly:
        install_str_list = ["--pre"] + install_str_list

    install_str = " ".join(install_str_list)
    return install_str

def get_index_url(info: TorchVersionInfo=None):
    '''
    Get --extra-index-url (or --index-url) for torch install.
    '''
    if info is None:
        info = get_torch_info()
    # for cpu, don't need any index_url
    if info.is_cpu and not info.is_nightly:
        return None
    # otherwise, format index_url
    base_url = "https://download.pytorch.org/whl/"
    if info.is_nightly:
        base_url = f"--index-url {base_url}nightly/"
    else:
        base_url = f"--extra-index-url {base_url}"
    base_url = f"{base_url}{info.extension}"
    return base_url

def get_torch_info(package=None):
    '''
    Get info about an installed torch-related package.
    '''
    if package is None:
        package = torch
    info = TorchVersionInfo(name=package.__name__)
    info.version = package.__version__
    info.extension = None
    info.is_nightly = False
    # get extension, separate from version
    info.version, info.extension = info.version.split('+', 1)
    if info.extension.startswith('cpu'):
        info.is_cpu = True
    elif info.extension.startswith('cu'):
        info.is_cuda = True
    elif info.extension.startswith('rocm'):
        info.is_rocm = True
    elif info.extension.startswith('xpu'):
        info.is_xpu = True
    # TODO: add checks for some odd pytorch versions, if possible

    # check if nightly install
    if 'dev' in info.version:
        info.is_nightly = True

    return info

def get_torch_directml_info():
    '''
    Get info specifically about torch-directml package.

    Returns None if torch-directml is not installed.
    '''
    # the import string and the pip string are different
    pip_name = "torch-directml"
    # if no torch_directml, do nothing
    if not importlib.util.find_spec("torch_directml"):
        return None
    info = TorchVersionInfo(name=pip_name)
    info.is_directml = True
    for p in pkg_resources.working_set:
        if p.project_name.lower() == pip_name:
            info.version = p.version
    if p.version is None:
        return None
    return info


if __name__ == '__main__':
    print(get_bootstrap_requirements_string())
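
For orientation, a minimal usage sketch (not part of the diff) of how the bootstrap string might feed a pip invocation when recreating the current torch stack in another venv; the target interpreter path, the helper name, and the subprocess call are assumptions:

import subprocess
from app.venv_management import get_bootstrap_requirements_string

def bootstrap_other_venv(python_exe):
    # Produces something like "torch==2.7.0+cu126 torchvision==0.22.0+cu126 ... --extra-index-url https://download.pytorch.org/whl/cu126"
    req_string = get_bootstrap_requirements_string()
    # Hypothetical: run pip with the target venv's interpreter using those pins
    subprocess.check_call([python_exe, "-m", "pip", "install", *req_string.split(" ")])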
5 comfy_api/torch_helpers/__init__.py Normal file
@@ -0,0 +1,5 @@
from .torch_compile import set_torch_compile_wrapper

__all__ = [
    "set_torch_compile_wrapper",
]
69 comfy_api/torch_helpers/torch_compile.py Normal file
@@ -0,0 +1,69 @@
from __future__ import annotations
import torch

import comfy.utils
from comfy.patcher_extension import WrappersMP
from typing import TYPE_CHECKING, Callable, Optional
if TYPE_CHECKING:
    from comfy.model_patcher import ModelPatcher
    from comfy.patcher_extension import WrapperExecutor


COMPILE_KEY = "torch.compile"
TORCH_COMPILE_KWARGS = "torch_compile_kwargs"


def apply_torch_compile_factory(compiled_module_dict: dict[str, Callable]) -> Callable:
    '''
    Create a wrapper that will refer to the compiled_diffusion_model.
    '''
    def apply_torch_compile_wrapper(executor: WrapperExecutor, *args, **kwargs):
        try:
            orig_modules = {}
            for key, value in compiled_module_dict.items():
                orig_modules[key] = comfy.utils.get_attr(executor.class_obj, key)
                comfy.utils.set_attr(executor.class_obj, key, value)
            return executor(*args, **kwargs)
        finally:
            for key, value in orig_modules.items():
                comfy.utils.set_attr(executor.class_obj, key, value)
    return apply_torch_compile_wrapper


def set_torch_compile_wrapper(model: ModelPatcher, backend: str, options: Optional[dict[str,str]]=None,
                              mode: Optional[str]=None, fullgraph=False, dynamic: Optional[bool]=None,
                              keys: list[str]=["diffusion_model"], *args, **kwargs):
    '''
    Perform torch.compile that will be applied at sample time for either the whole model or specific params of the BaseModel instance.

    When keys is None, it will default to using ["diffusion_model"], compiling the whole diffusion_model.
    When a list of keys is provided, it will perform torch.compile on only the selected modules.
    '''
    # clear out any other torch.compile wrappers
    model.remove_wrappers_with_key(WrappersMP.APPLY_MODEL, COMPILE_KEY)
    # if no keys, default to 'diffusion_model'
    if not keys:
        keys = ["diffusion_model"]
    # create kwargs dict that can be referenced later
    compile_kwargs = {
        "backend": backend,
        "options": options,
        "mode": mode,
        "fullgraph": fullgraph,
        "dynamic": dynamic,
    }
    # get a dict of compiled keys
    compiled_modules = {}
    for key in keys:
        compiled_modules[key] = torch.compile(
            model=model.get_model_object(key),
            **compile_kwargs,
        )
    # add torch.compile wrapper
    wrapper_func = apply_torch_compile_factory(
        compiled_module_dict=compiled_modules,
    )
    # store wrapper to run on BaseModel's apply_model function
    model.add_wrapper_with_key(WrappersMP.APPLY_MODEL, COMPILE_KEY, wrapper_func)
    # keep compile kwargs for reference
    model.model_options[TORCH_COMPILE_KWARGS] = compile_kwargs
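
For orientation, a minimal sketch (not part of the diff) of how a node's patch function might call this helper, mirroring the TorchCompileModel change below; the "inductor" backend default and the partial-compile key in the comment are illustrative assumptions, and valid keys depend on the loaded model's module layout:

from comfy_api.torch_helpers import set_torch_compile_wrapper

def patch(self, model, backend="inductor"):
    m = model.clone()
    # Registers an APPLY_MODEL wrapper that swaps in compiled modules at sample time
    set_torch_compile_wrapper(model=m, backend=backend)
    # Hypothetically, specific submodules could be targeted instead, e.g.:
    # set_torch_compile_wrapper(model=m, backend=backend, keys=["diffusion_model.input_blocks"])
    return (m, )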
@@ -1,4 +1,5 @@
-import torch
+from comfy_api.torch_helpers import set_torch_compile_wrapper
+

 class TorchCompileModel:
     @classmethod
@@ -14,7 +15,7 @@ class TorchCompileModel:

     def patch(self, model, backend):
         m = model.clone()
-        m.add_object_patch("diffusion_model", torch.compile(model=m.get_model_object("diffusion_model"), backend=backend))
+        set_torch_compile_wrapper(model=m, backend=backend)
         return (m, )

 NODE_CLASS_MAPPINGS = {
@@ -1,3 +1,3 @@
 # This file is automatically generated by the build process when version is
 # updated in pyproject.toml.
-__version__ = "0.3.34"
+__version__ = "0.3.35"
@@ -909,7 +909,6 @@ class PromptQueue:
         self.currently_running = {}
         self.history = {}
         self.flags = {}
-        server.prompt_queue = self

     def put(self, item):
         with self.mutex:
@@ -954,6 +953,7 @@ class PromptQueue:
             self.history[prompt[1]].update(history_result)
             self.server.queue_updated()

+    # Note: slow
     def get_current_queue(self):
         with self.mutex:
             out = []
@@ -961,6 +961,13 @@ class PromptQueue:
                 out += [x]
             return (out, copy.deepcopy(self.queue))

+    # read-safe as long as queue items are immutable
+    def get_current_queue_volatile(self):
+        with self.mutex:
+            running = [x for x in self.currently_running.values()]
+            queued = copy.copy(self.queue)
+            return (running, queued)
+
     def get_tasks_remaining(self):
         with self.mutex:
             return len(self.queue) + len(self.currently_running)
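
For context, a small sketch (not from the PR) of why the new getter is cheaper: copy.copy duplicates only the outer list, while copy.deepcopy clones every nested prompt structure, which is why get_current_queue is now marked "# Note: slow"; the queue item shape below is an illustrative assumption:

import copy

# Hypothetical queue items: tuples carrying a nested prompt dict
queue = [(i, f"prompt-{i}", {"nodes": {str(n): {} for n in range(200)}}, {}, []) for i in range(500)]

shallow = copy.copy(queue)    # new outer list, same item objects; read-safe only if items stay immutable
deep = copy.deepcopy(queue)   # recursively clones every nested dict/list, far more work per request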
3 main.py
@@ -260,7 +260,6 @@ def start_comfyui(asyncio_loop=None):
         asyncio_loop = asyncio.new_event_loop()
     asyncio.set_event_loop(asyncio_loop)
     prompt_server = server.PromptServer(asyncio_loop)
-    q = execution.PromptQueue(prompt_server)

     hook_breaker_ac10a0.save_functions()
     nodes.init_extra_nodes(init_custom_nodes=not args.disable_all_custom_nodes, init_api_nodes=not args.disable_api_nodes)
@@ -271,7 +270,7 @@ def start_comfyui(asyncio_loop=None):
     prompt_server.add_routes()
     hijack_progress(prompt_server)

-    threading.Thread(target=prompt_worker, daemon=True, args=(q, prompt_server,)).start()
+    threading.Thread(target=prompt_worker, daemon=True, args=(prompt_server.prompt_queue, prompt_server,)).start()

     if args.quick_test_for_ci:
         exit(0)
@@ -1,6 +1,6 @@
 [project]
 name = "ComfyUI"
-version = "0.3.34"
+version = "0.3.35"
 readme = "README.md"
 license = { file = "LICENSE" }
 requires-python = ">=3.9"
@@ -1,5 +1,5 @@
 comfyui-frontend-package==1.19.9
-comfyui-workflow-templates==0.1.14
+comfyui-workflow-templates==0.1.18
 torch
 torchsde
 torchvision
@@ -29,6 +29,7 @@ import comfy.model_management
 import node_helpers
 from comfyui_version import __version__
 from app.frontend_management import FrontendManager
+
 from app.user_manager import UserManager
 from app.model_manager import ModelFileManager
 from app.custom_node_manager import CustomNodeManager
@@ -159,7 +160,7 @@ class PromptServer():
         self.custom_node_manager = CustomNodeManager()
         self.internal_routes = InternalRoutes(self)
         self.supports = ["custom_nodes_from_web"]
-        self.prompt_queue = None
+        self.prompt_queue = execution.PromptQueue(self)
         self.loop = loop
         self.messages = asyncio.Queue()
         self.client_session:Optional[aiohttp.ClientSession] = None
@@ -226,7 +227,7 @@ class PromptServer():
             return response

         @routes.get("/embeddings")
-        def get_embeddings(self):
+        def get_embeddings(request):
             embeddings = folder_paths.get_filename_list("embeddings")
             return web.json_response(list(map(lambda a: os.path.splitext(a)[0], embeddings)))

@@ -282,7 +283,6 @@ class PromptServer():
                     a.update(f.read())
                     b.update(image.file.read())
                     image.file.seek(0)
-                    f.close()
                 return a.hexdigest() == b.hexdigest()
             return False

@@ -621,7 +621,7 @@ class PromptServer():
         @routes.get("/queue")
         async def get_queue(request):
             queue_info = {}
-            current_queue = self.prompt_queue.get_current_queue()
+            current_queue = self.prompt_queue.get_current_queue_volatile()
             queue_info['queue_running'] = current_queue[0]
             queue_info['queue_pending'] = current_queue[1]
             return web.json_response(queue_info)
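
As a quick end-to-end check (not part of the diff), the /queue route can be queried on a running instance; the default 127.0.0.1:8188 address is an assumption about the local setup:

import json, urllib.request

# After this change the route is backed by get_current_queue_volatile() (shallow copies)
with urllib.request.urlopen("http://127.0.0.1:8188/queue") as resp:
    queue_info = json.load(resp)
print(len(queue_info["queue_running"]), len(queue_info["queue_pending"]))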