Background
When we run a large model with vLLM, the logs show it going through many stages: physical device detection, argument validation, GPU memory checks, model inspection, model loading, KV cache management, distributed inference, and so on, before an HTTP server is finally started for us to call. To understand what is happening behind all of this, let's read through the source code.
This time we focus on the model loading flow. The command line used to run the model is:
```
vllm serve /var/lib/cache/model_scope/deepseek-ai/DeepSeek-OCR --trust-remote-code --gpu-memory-utilization=0.3 --host 0.0.0.0 --port 40051 --served-model-name deepseek-ocr
```
The vLLM version used is v0.11.1.
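Once the server is up, it exposes an OpenAI-compatible HTTP API. As a quick sanity check (not part of the startup flow itself), something like the following should work against the flags above; the prompt text is just a placeholder:

```
# Minimal sketch of calling the server started above; assumes the OpenAI-compatible
# /v1/chat/completions route and that the server is reachable on localhost:40051.
import requests

resp = requests.post(
    "http://127.0.0.1:40051/v1/chat/completions",
    json={
        "model": "deepseek-ocr",  # matches --served-model-name
        "messages": [{"role": "user", "content": "Describe this document."}],
        "max_tokens": 64,
    },
    timeout=60,
)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])
```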
Entry function
The vLLM entry point is vllm/entrypoints/cli/serve.py:
```
class ServeSubcommand(CLISubcommand):
    """The `serve` subcommand for the vLLM CLI."""

    name = "serve"

    @staticmethod
    def cmd(args: argparse.Namespace) -> None:
        # If model is specified in CLI (as positional arg), it takes precedence
        if hasattr(args, "model_tag") and args.model_tag is not None:
            args.model = args.model_tag

        # By default, headless is False and api_server_count is 1
        if args.headless or args.api_server_count < 1:
            run_headless(args)
        else:
            if args.api_server_count > 1:
                run_multi_api_server(args)
            else:
                # So we end up in run_server
                # Single API server (this process).
                uvloop.run(run_server(args))
```
run_server creates the listening socket from the arguments; the interesting part is run_server_worker:
```
async def run_server_worker(
    listen_address, sock, args, client_config=None, **uvicorn_kwargs
) -> None:
    ...
    # The model gets loaded inside build_async_engine_client
    async with build_async_engine_client(
        args,
        client_config=client_config,
    ) as engine_client:
        maybe_register_tokenizer_info_endpoint(args)
        app = build_app(args)

        await init_app_state(engine_client, app.state, args)

        logger.info(
            "Starting vLLM API server %d on %s",
            engine_client.vllm_config.parallel_config._api_process_rank,
            listen_address,
        )

        # By this point the model is loaded and the API can serve external requests
        shutdown_task = await serve_http(
            app,
            sock=sock,
            enable_ssl_refresh=args.enable_ssl_refresh,
            host=args.host,
            port=args.port,
            log_level=args.uvicorn_log_level,
            # NOTE: When the 'disable_uvicorn_access_log' value is True,
            # no access log will be output.
            access_log=not args.disable_uvicorn_access_log,
            timeout_keep_alive=envs.VLLM_HTTP_TIMEOUT_KEEP_ALIVE,
            ssl_keyfile=args.ssl_keyfile,
            ssl_certfile=args.ssl_certfile,
            ssl_ca_certs=args.ssl_ca_certs,
            ssl_cert_reqs=args.ssl_cert_reqs,
            h11_max_incomplete_event_size=args.h11_max_incomplete_event_size,
            h11_max_header_count=args.h11_max_header_count,
            **uvicorn_kwargs,
        )

    # NB: Await server shutdown only after the backend context is exited
    try:
        await shutdown_task
    finally:
        sock.close()
```
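Note that build_async_engine_client is used as an async context manager: the engine (and therefore the model) is built when the `async with` block is entered, and shut down when it exits. A minimal sketch of that pattern, with made-up names rather than the actual vLLM signatures:

```
# Illustrative sketch of the async-contextmanager pattern used by
# build_async_engine_client; FakeEngine and its methods are invented for the example.
import asyncio
from contextlib import asynccontextmanager


class FakeEngine:
    async def start(self) -> None:
        print("loading model ...")

    def shutdown(self) -> None:
        print("shutting down engine")


@asynccontextmanager
async def build_engine_client():
    engine = FakeEngine()
    await engine.start()      # heavy work (model load) happens on entry
    try:
        yield engine          # the caller uses the engine inside the async with block
    finally:
        engine.shutdown()     # always cleaned up, even on error


async def main():
    async with build_engine_client() as engine:
        print("serving with", engine)


asyncio.run(main())
```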
build_async_engine_client wraps the CLI arguments into an AsyncEngineArgs object; the key function is build_async_engine_client_from_engine_args, called with:
- usage_context=UsageContext.OPENAI_API_SERVER
- disable_frontend_multiprocessing=False
- client_config=None
```
async def build_async_engine_client_from_engine_args(
    engine_args: AsyncEngineArgs,
    *,
    usage_context: UsageContext = UsageContext.OPENAI_API_SERVER,
    disable_frontend_multiprocessing: bool = False,
    client_config: dict[str, Any] | None = None,
) -> AsyncIterator[EngineClient]:
    ...
    # Build the vLLM config
    vllm_config = engine_args.create_engine_config(usage_context=usage_context)

    # V1 AsyncLLM.
    assert envs.VLLM_USE_V1
    ...
    # Create the async LLM; constructing it also loads the model weight files
    try:
        async_llm = AsyncLLM.from_vllm_config(
            vllm_config=vllm_config,
            usage_context=usage_context,
            enable_log_requests=engine_args.enable_log_requests,
            aggregate_engine_logging=engine_args.aggregate_engine_logging,
            disable_log_stats=engine_args.disable_log_stats,
            client_addresses=client_config,
            client_count=client_count,
            client_index=client_index,
        )

        # Don't keep the dummy data in memory
        await async_llm.reset_mm_cache()

        yield async_llm
    finally:
        if async_llm:
            async_llm.shutdown()
```
from_vllm_config itself is straightforward, but note Executor.get_class(vllm_config), which selects the class that will load and drive the model; vLLM calls this module the executor. On a single machine the distributed-executor-backend is uni, a distributed cluster can use ray, and single-node multi-GPU uses mp. In this walkthrough we run in single-process (uni) mode.
```
def from_vllm_config(
    cls,
    vllm_config: VllmConfig,
    start_engine_loop: bool = True,
    usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
    stat_loggers: list[StatLoggerFactory] | None = None,
    enable_log_requests: bool = False,
    aggregate_engine_logging: bool = False,
    disable_log_stats: bool = False,
    client_addresses: dict[str, str] | None = None,
    client_count: int = 1,
    client_index: int = 0,
    disable_log_requests: bool = True,  # Deprecated, will be removed
) -> "AsyncLLM":
    ...
    # Executor.get_class(vllm_config) picks the executor class; since we run
    # in single-process mode, it returns UniProcExecutor
    return cls(
        vllm_config=vllm_config,
        executor_class=Executor.get_class(vllm_config),
        start_engine_loop=start_engine_loop,
        stat_loggers=stat_loggers,
        log_requests=enable_log_requests,
        log_stats=not disable_log_stats,
        aggregate_engine_logging=aggregate_engine_logging,
        usage_context=usage_context,
        client_addresses=client_addresses,
        client_count=client_count,
        client_index=client_index,
    )
```
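Conceptually, Executor.get_class maps the configured backend to an executor implementation. The sketch below is a simplified assumption, not the real vLLM code (which also validates the platform and supports custom executor classes); the backend names match the uni/mp/ray values mentioned above, while the dispatch table itself is illustrative:

```
# Simplified, assumed sketch of the backend -> executor-class dispatch done by
# Executor.get_class(); the placeholder classes stand in for vLLM's executors.

class UniProcExecutor: ...         # single process, single GPU ("uni", our case)
class MultiprocExecutor: ...       # one process per GPU on a single node ("mp")
class RayDistributedExecutor: ...  # multi-node via Ray ("ray")

_BACKENDS = {
    "uni": UniProcExecutor,
    "mp": MultiprocExecutor,
    "ray": RayDistributedExecutor,
}

def get_executor_class(distributed_executor_backend: str | None) -> type:
    # With a single GPU the backend defaults to "uni", so UniProcExecutor is returned.
    return _BACKENDS.get(distributed_executor_backend or "uni", UniProcExecutor)

print(get_executor_class("uni"))   # <class '__main__.UniProcExecutor'>
```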
The AsyncLLM constructor is where the model files get loaded. Let's look at the tokenizer and the model executor separately to see how each is implemented.
```
class AsyncLLM(EngineClient):
    def __init__(
        ...
    ) -> None:
        ...
        self.model_config = vllm_config.model_config
        self.vllm_config = vllm_config
        self.observability_config = vllm_config.observability_config
        self.log_requests = log_requests

        custom_stat_loggers = list(stat_loggers or [])
        custom_stat_loggers.extend(load_stat_logger_plugin_factories())
        has_custom_loggers = bool(custom_stat_loggers)
        self.log_stats = log_stats or has_custom_loggers
        if not log_stats and has_custom_loggers:
            logger.info(
                "AsyncLLM created with log_stats=False, "
                "but custom stat loggers were found; "
                "enabling logging without default stat loggers."
            )

        # skip_tokenizer_init is False by default
        if self.model_config.skip_tokenizer_init:
            tokenizer = None
        else:
            # Load the tokenizer; discussed below
            tokenizer = init_tokenizer_from_configs(self.model_config)

        # Handles the model's input data: argument validation, text preprocessing,
        # multimodal data processing
        self.processor = Processor(self.vllm_config, tokenizer)
        # No io_processor is specified by default, so this is empty
        self.io_processor = get_io_processor(
            self.vllm_config,
            self.model_config.io_processor_plugin,
        )

        # OutputProcessor (converts EngineCoreOutputs --> RequestOutput).
        self.output_processor = OutputProcessor(
            self.tokenizer, log_stats=self.log_stats
        )
        ...
        # This is where the model gets loaded
        self.engine_core = EngineCoreClient.make_async_mp_client(
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=self.log_stats,
            client_addresses=client_addresses,
            client_count=client_count,
            client_index=client_index,
        )
        ...
        # Output handler
        self.output_handler: asyncio.Task | None = None
        try:
            # Start output handler eagerly if we are in the asyncio eventloop.
            asyncio.get_running_loop()
            self._run_output_handler()
        except RuntimeError:
            pass
        ...
        else:
            self.profiler = None
```
The tokenizer is initialized in init_tokenizer_from_configs. If the tokenizer files have to be downloaded from a remote repo, the weight files are excluded from the download; here we read the local model directory directly. By default AutoTokenizer.from_pretrained is used, which resolves the tokenizer_class attribute from tokenizer_config.json:
```
tokenizer = AutoTokenizer.from_pretrained(
    tokenizer_name,
    *args,
    trust_remote_code=trust_remote_code,
    revision=revision,
    **kwargs,
)
```
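For reference, loading the tokenizer straight from the local model directory and round-tripping a string looks like this; the path matches the serve command above, and trust_remote_code=True mirrors the --trust-remote-code flag:

```
# Standalone check of the tokenizer load path; assumes the transformers package
# and the local model directory from the serve command above.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "/var/lib/cache/model_scope/deepseek-ai/DeepSeek-OCR",
    trust_remote_code=True,
)

ids = tokenizer.encode("hello vllm")
print(ids)
print(tokenizer.decode(ids))
```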
Next, let's look at make_async_mp_client. Since data_parallel_size=1, it ultimately returns an AsyncMPClient. The AsyncMPClient constructor itself is unremarkable; what matters is how the parent class handles executor_class.
```
def make_async_mp_client(
    vllm_config: VllmConfig,
    executor_class: type[Executor],
    log_stats: bool,
    client_addresses: dict[str, str] | None = None,
    client_count: int = 1,
    client_index: int = 0,
) -> "MPClient":
    parallel_config = vllm_config.parallel_config
    client_args = (
        vllm_config,
        executor_class,
        log_stats,
        client_addresses,
        client_count,
        client_index,
    )
    if parallel_config.data_parallel_size > 1:
        if parallel_config.data_parallel_external_lb:
            # External load balancer - client per DP rank.
            return DPAsyncMPClient(*client_args)
        # Internal load balancer - client balances to all DP ranks.
        return DPLBAsyncMPClient(*client_args)
    return AsyncMPClient(*client_args)
```
In MPClient's __init__, the parameters are:
- asyncio_mode=True
- executor_class=<class 'vllm.v1.executor.uniproc_executor.UniProcExecutor'>
- log_stats=True
- client_addresses={}
```
class MPClient(EngineCoreClient):
    ...
    def __init__(
        self,
        asyncio_mode: bool,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
        client_addresses: dict[str, str] | None = None,
    ):
        self.vllm_config = vllm_config
        # Serialization setup.
        self.encoder = MsgpackEncoder()
        self.decoder = MsgpackDecoder(EngineCoreOutputs)

        # ZMQ setup.
        sync_ctx = zmq.Context(io_threads=2)
        self.ctx = zmq.asyncio.Context(sync_ctx) if asyncio_mode else sync_ctx

        # This will ensure resources created so far are closed
        # when the client is garbage collected, even if an
        # exception is raised mid-construction.
        self.resources = BackgroundResources(ctx=sync_ctx)
        self._finalizer = weakref.finalize(self, self.resources)
        success = False
        try:
            # State used for data parallel.
            self.engines_running = False
            self.stats_update_address: str | None = None
            if client_addresses:
                ...
            else:
                # Engines are managed by this client.
                with launch_core_engines(vllm_config, executor_class, log_stats) as (
                    engine_manager,
                    coordinator,
                    addresses,
                ):
                    self.resources.coordinator = coordinator
                    self.resources.engine_manager = engine_manager

                (input_address,) = addresses.inputs
                (output_address,) = addresses.outputs
                self.stats_update_address = addresses.frontend_stats_publish_address
                if coordinator is not None:
                    assert self.stats_update_address == (
                        coordinator.get_stats_publish_address()
                    )

            # Create input and output sockets.
            self.input_socket = self.resources.input_socket = make_zmq_socket(
                self.ctx, input_address, zmq.ROUTER, bind=True
            )
            self.resources.output_socket = make_zmq_socket(
                self.ctx, output_address, zmq.PULL
            )

            parallel_config = vllm_config.parallel_config
            dp_size = parallel_config.data_parallel_size
            dp_rank = parallel_config.data_parallel_rank
            dp_local_size = parallel_config.data_parallel_size_local
            offline_mode = parallel_config.data_parallel_rank_local is not None
            # Client manages local+remote EngineCores in pure internal LB case.
            # Client manages local EngineCores in hybrid and external LB case.
            local_engines_only = (
                parallel_config.data_parallel_hybrid_lb
                or parallel_config.data_parallel_external_lb
            )
            num_ranks = dp_local_size if local_engines_only else dp_size
            self.engine_ranks_managed = (
                [dp_rank]
                if offline_mode
                else list(range(dp_rank, dp_rank + num_ranks))
            )
            assert parallel_config.data_parallel_size_local <= len(
                self.engine_ranks_managed
            )

            # ZMQ identity of each engine that this client will talk to.
            self.core_engines: list[EngineIdentity] = [
                rank.to_bytes(2, "little") for rank in self.engine_ranks_managed
            ]

            # Wait for ready messages from each engine on the input socket.
            identities = set(self.core_engines)
            sync_input_socket = zmq.Socket.shadow(self.input_socket)
            while identities:
                if not sync_input_socket.poll(timeout=600_000):
                    raise TimeoutError(
                        "Timed out waiting for engines to send"
                        "initial message on input socket."
                    )
                identity, _ = sync_input_socket.recv_multipart()
                identities.remove(identity)

            self.core_engine: EngineIdentity = self.core_engines[0]
            self.utility_results: dict[int, AnyFuture] = {}

            # Request objects which may contain pytorch-allocated tensors
            # that we need to keep references to until zmq is done with the
            # underlying data.
            self.pending_messages = deque[tuple[zmq.MessageTracker, Any]]()

            # Start monitoring engine core processes for unexpected failures
            self.start_engine_core_monitor()

            success = True
        finally:
            if not success:
                self._finalizer()
```
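The "wait for ready messages" part above is a plain ZeroMQ ROUTER/DEALER handshake: the client binds a ROUTER socket, each engine connects with a fixed identity (its rank encoded as two little-endian bytes) and sends a first message, and the client polls until it has heard from every expected identity. A stripped-down sketch of that handshake using pyzmq, with threads standing in for engine processes and a made-up address and payload:

```
# Minimal ROUTER/DEALER ready-handshake sketch using pyzmq; the address and the
# "READY" payload are illustrative, not vLLM's actual wire format.
import threading
import zmq

ADDR = "tcp://127.0.0.1:5555"

def engine(rank: int) -> None:
    ctx = zmq.Context.instance()
    sock = ctx.socket(zmq.DEALER)
    sock.setsockopt(zmq.IDENTITY, rank.to_bytes(2, "little"))  # same identity scheme
    sock.connect(ADDR)
    sock.send(b"READY")          # first message announces the engine is up
    sock.close()

# Client side: bind ROUTER and wait for every expected engine identity.
ctx = zmq.Context.instance()
router = ctx.socket(zmq.ROUTER)
router.bind(ADDR)

expected = {r.to_bytes(2, "little") for r in range(1)}   # single engine here
threads = [threading.Thread(target=engine, args=(r,)) for r in range(1)]
for t in threads:
    t.start()

while expected:
    if not router.poll(timeout=5_000):
        raise TimeoutError("engine never reported ready")
    identity, payload = router.recv_multipart()
    expected.discard(identity)
    print("engine", int.from_bytes(identity, "little"), "is ready:", payload)

for t in threads:
    t.join()
router.close()
```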
Since we are running on a single machine with a single GPU, data_parallel_size and data_parallel_size_local are both 1. The core of the code above is CoreEngineProcManager, launched inside launch_core_engines; let's keep going.
```
def launch_core_engines(
    vllm_config: VllmConfig,
    executor_class: type[Executor],
    log_stats: bool,
    num_api_servers: int = 1,
) -> Iterator[
    tuple[
        CoreEngineProcManager | CoreEngineActorManager | None,
        DPCoordinator | None,
        EngineZmqAddresses,
    ]
]:
    ...
    # local_start_index=0, so offline_mode=True
    offline_mode = local_start_index is not None

    # client_local_only = True
    client_local_only = (
        offline_mode or local_engines_only or (local_engine_count == dp_size)
    )

    # Set up the input/output ZMQ addresses
    addresses = EngineZmqAddresses(
        inputs=[
            get_engine_client_zmq_addr(client_local_only, host)
            for _ in range(num_api_servers)
        ],
        outputs=[
            get_engine_client_zmq_addr(client_local_only, host)
            for _ in range(num_api_servers)
        ],
    )

    # Single machine, single GPU, so run_coordinator=False and data_parallel_backend=uni
    ...
    if offline_mode:
        assert local_engine_count == 1
        engines_to_handshake = [CoreEngine(index=dp_rank, local=True)]
    ...
    # handshake_local_only = True
    handshake_local_only = offline_mode or local_engine_count == dp_size

    handshake_address = get_engine_client_zmq_addr(
        handshake_local_only, host, parallel_config.data_parallel_rpc_port
    )

    if local_engines_only and dp_rank > 0:
        assert not handshake_local_only
        local_handshake_address = get_open_zmq_ipc_path()
        client_handshake_address = local_handshake_address
    else:
        local_handshake_address = handshake_address
        client_handshake_address = None

    with zmq_socket_ctx(
        local_handshake_address, zmq.ROUTER, bind=True
    ) as handshake_socket:
        from vllm.v1.engine.core import EngineCoreProc

        # Start the local engine(s)
        if local_engine_count:
            local_engine_manager = CoreEngineProcManager(
                EngineCoreProc.run_engine_core,
                vllm_config=vllm_config,
                executor_class=executor_class,
                log_stats=log_stats,
                handshake_address=handshake_address,
                client_handshake_address=client_handshake_address,
                local_client=True,
                local_engine_count=local_engine_count,
                start_index=dp_rank,
                local_start_index=local_start_index or 0,
            )
        else:
            local_engine_manager = None

        yield local_engine_manager, coordinator, addresses

        # Wait for the local engine(s) to come up
        wait_for_engine_startup(
            handshake_socket,
            addresses,
            engines_to_handshake,
            parallel_config,
            vllm_config.cache_config,
            local_engine_manager,
            coordinator.proc if coordinator else None,
        )
```
CoreEngineProcManager's constructor creates one process per local GPU, each running target_fn. Here target_fn is EngineCoreProc.run_engine_core, so let's continue with run_engine_core.
```
class CoreEngineProcManager:
    def __init__(
        self,
        target_fn: Callable,
        local_engine_count: int,
        start_index: int,
        local_start_index: int,
        vllm_config: VllmConfig,
        local_client: bool,
        handshake_address: str,
        executor_class: type[Executor],
        log_stats: bool,
        client_handshake_address: str | None = None,
    ):
        context = get_mp_context()
        common_kwargs = {
            "vllm_config": vllm_config,
            "local_client": local_client,
            "handshake_address": handshake_address,
            "executor_class": executor_class,
            "log_stats": log_stats,
        }

        if client_handshake_address:
            common_kwargs["client_handshake_address"] = client_handshake_address

        self.processes: list[BaseProcess] = []
        local_dp_ranks = []
        # Single machine, single GPU, so local_engine_count=1
        for index in range(local_engine_count):
            local_index = local_start_index + index
            global_index = start_index + index

            # Start EngineCore in background process.
            local_dp_ranks.append(local_index)
            # A separate process is added here to run target_fn; note that
            # executor_class is passed along in the kwargs
            self.processes.append(
                context.Process(
                    target=target_fn,
                    name=f"EngineCore_DP{global_index}",
                    kwargs=common_kwargs
                    | {
                        "dp_rank": global_index,
                        "local_dp_rank": local_index,
                    },
                )
            )

        self._finalizer = weakref.finalize(self, shutdown, self.processes)

        data_parallel = vllm_config.parallel_config.data_parallel_size > 1
        try:
            for proc, local_dp_rank in zip(self.processes, local_dp_ranks):
                # Adjust device control in DP for non-CUDA platforms
                # For CUDA platforms, setting same device id for different DP
                # processes affects NCCL init performance.
                with (
                    set_device_control_env_var(vllm_config, local_dp_rank)
                    if (data_parallel and not current_platform.is_cuda_alike())
                    else contextlib.nullcontext()
                ):
                    # Start the process
                    proc.start()
        finally:
            # Kill other procs if not all are running.
            if self.finished_procs():
                self.close()
```
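So each local engine is just a regular multiprocessing process running EngineCoreProc.run_engine_core. The pattern in isolation (spawn-style context, keyword args, named process) looks like the sketch below, where run_engine is a stand-in for the real target function:

```
# Sketch of the process-per-engine pattern used by CoreEngineProcManager;
# run_engine is a stand-in for EngineCoreProc.run_engine_core.
import multiprocessing as mp

def run_engine(*, dp_rank: int, local_dp_rank: int) -> None:
    print(f"engine process started: dp_rank={dp_rank}, local_dp_rank={local_dp_rank}")

if __name__ == "__main__":
    context = mp.get_context("spawn")   # vLLM obtains a similar context via get_mp_context()
    procs = []
    for index in range(1):              # single local engine in our setup
        procs.append(
            context.Process(
                target=run_engine,
                name=f"EngineCore_DP{index}",
                kwargs={"dp_rank": index, "local_dp_rank": index},
            )
        )
    for p in procs:
        p.start()
    for p in procs:
        p.join()
```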
The core of run_engine_core is constructing an EngineCoreProc:
```
def run_engine_core(*args, dp_rank: int = 0, local_dp_rank: int = 0, **kwargs):
    ...
    engine_core: EngineCoreProc | None = None
    try:
        parallel_config: ParallelConfig = kwargs["vllm_config"].parallel_config
        if parallel_config.data_parallel_size > 1 or dp_rank > 0:
            ...
        else:
            set_process_title("EngineCore")
            decorate_logs()
            engine_core = EngineCoreProc(*args, **kwargs)

        engine_core.run_busy_loop()
    ...
```
Still no sign of executor_class being instantiated, which is getting tiring. Let's jump straight to where it actually happens: EngineCore.__init__ (EngineCoreProc builds on EngineCore).
```
class EngineCore:
    """Inner loop of vLLM's Engine."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
        executor_fail_callback: Callable | None = None,
    ):
        ...
        self.vllm_config = vllm_config
        if vllm_config.parallel_config.data_parallel_rank == 0:
            logger.info(
                "Initializing a V1 LLM engine (v%s) with config: %s",
                VLLM_VERSION,
                vllm_config,
            )

        self.log_stats = log_stats

        # Here it is at last: the executor is instantiated,
        # with executor_class=UniProcExecutor
        self.model_executor = executor_class(vllm_config)
        if executor_fail_callback is not None:
            self.model_executor.register_failure_callback(executor_fail_callback)

        self.available_gpu_memory_for_kv_cache = -1

        # Setup KV Caches and update CacheConfig after profiling.
        num_gpu_blocks, num_cpu_blocks, kv_cache_config = self._initialize_kv_caches(
            vllm_config
        )

        vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
        vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks
        self.collective_rpc("initialize_cache", args=(num_gpu_blocks, num_cpu_blocks))

        self.structured_output_manager = StructuredOutputManager(vllm_config)

        # Here scheduler_cls is vllm.v1.core.sched.scheduler.Scheduler
        if isinstance(vllm_config.scheduler_config.scheduler_cls, str):
            Scheduler = resolve_obj_by_qualname(
                vllm_config.scheduler_config.scheduler_cls
            )
        else:
            Scheduler = vllm_config.scheduler_config.scheduler_cls

        ...
        if len(kv_cache_config.kv_cache_groups) == 0:
            # Encoder models without KV cache don't support
            # chunked prefill. But do SSM models?
            logger.info("Disabling chunked prefill for model without KVCache")
            vllm_config.scheduler_config.chunked_prefill_enabled = False

        scheduler_block_size = (
            vllm_config.cache_config.block_size
            * vllm_config.parallel_config.decode_context_parallel_size
        )

        # Initialize the Scheduler
        self.scheduler: SchedulerInterface = Scheduler(
            vllm_config=vllm_config,
            kv_cache_config=kv_cache_config,
            structured_output_manager=self.structured_output_manager,
            include_finished_set=vllm_config.parallel_config.data_parallel_size > 1,
            log_stats=self.log_stats,
            block_size=scheduler_block_size,
        )
        ...
```
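A small aside: resolve_obj_by_qualname turns the string "vllm.v1.core.sched.scheduler.Scheduler" into the actual class by importing the module and fetching the attribute. A minimal equivalent, written here for illustration (vLLM's own helper presumably does essentially this, possibly with more error handling):

```
# Minimal qualname resolver for illustration; mirrors what a helper like
# resolve_obj_by_qualname typically does.
import importlib

def resolve_by_qualname(qualname: str):
    module_name, _, attr = qualname.rpartition(".")
    module = importlib.import_module(module_name)
    return getattr(module, attr)

# e.g. resolving a stdlib class the same way the scheduler class string is resolved:
OrderedDict = resolve_by_qualname("collections.OrderedDict")
print(OrderedDict)   # <class 'collections.OrderedDict'>
```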
Now the UniProcExecutor finally gets set up; the key part is the three calls at the end:
```
class UniProcExecutor(Executor):
    def _init_executor(self) -> None:
        """Initialize the worker and load the model."""
        self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config, rpc_rank=0)

        distributed_init_method, rank, local_rank = self._distributed_args()

        kwargs = dict(
            vllm_config=self.vllm_config,
            local_rank=local_rank,
            rank=rank,
            distributed_init_method=distributed_init_method,
            is_driver_worker=True,
            shared_worker_lock=Lock(),
        )

        self.async_output_thread: ThreadPoolExecutor | None = None
        if self.max_concurrent_batches > 1:
            self.async_output_thread = ThreadPoolExecutor(
                max_workers=1, thread_name_prefix="WorkerAsyncOutput"
            )

        # Load the worker: vllm.v1.worker.gpu_worker.Worker
        self.driver_worker.init_worker(all_kwargs=[kwargs])
        # Call vllm.v1.worker.gpu_worker.Worker's init_device function;
        # most importantly, this initializes the model_runner
        self.driver_worker.init_device()
        # Internally calls self.model_runner.load_model
        self.driver_worker.load_model()
```
Next, let's look at that load_model function:
```
def load_model(self, eep_scale_up: bool = False) -> None:
    ...
    # In this scenario eep_scale_up=False
    if eep_scale_up:
        ...
    else:
        global_expert_load = None
        old_global_expert_indices = None
        rank_mapping = None

    with DeviceMemoryProfiler() as m:
        time_before_load = time.perf_counter()
        # Default load format is "auto", which maps to DefaultModelLoader
        model_loader = get_model_loader(self.load_config)
        self.model = model_loader.load_model(
            vllm_config=self.vllm_config, model_config=self.model_config
        )
        ...
    # wrap the model with full cudagraph wrapper if needed.
    if (
        self.compilation_config.cudagraph_mode.has_full_cudagraphs()
        and not self.parallel_config.enable_dbo
    ):
        # The loaded model ultimately ends up wrapped as self.model
        self.model = CUDAGraphWrapper(
            self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL
        )
    elif self.parallel_config.enable_dbo:
        ...
```
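DeviceMemoryProfiler records GPU memory usage before and after the `with` block, which is what feeds the log line about how much memory and time model loading took. A rough equivalent with plain torch (assuming a CUDA device is available; this is not vLLM's implementation, just the same idea):

```
# Rough sketch of measuring the memory consumed while loading something onto the GPU;
# DeviceMemoryProfiler serves a similar purpose, this is not its implementation.
import time
import torch

assert torch.cuda.is_available()

free_before, _ = torch.cuda.mem_get_info()
t0 = time.perf_counter()

model = torch.nn.Linear(4096, 4096).to("cuda")   # stand-in for the real model load

torch.cuda.synchronize()
free_after, _ = torch.cuda.mem_get_info()
print(f"consumed {(free_before - free_after) / 2**20:.1f} MiB "
      f"in {time.perf_counter() - t0:.2f}s")
```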
We still have to go one level further down, into the model loader's load_model:
```
def load_model(
    self, vllm_config: VllmConfig, model_config: ModelConfig
) -> nn.Module:
    """Load a model with the given configurations."""
    device_config = vllm_config.device_config
    load_config = vllm_config.load_config
    load_device = (
        device_config.device if load_config.device is None else load_config.device
    )
    target_device = torch.device(load_device)
    with set_default_torch_dtype(model_config.dtype):
        with target_device:
            # Instantiate the model class named by the architectures field
            # of config.json in the model directory
            model = initialize_model(
                vllm_config=vllm_config, model_config=model_config
            )

        logger.debug("Loading weights on %s ...", load_device)
        # Quantization does not happen in `load_weights` but after it
        self.load_weights(model, model_config)
        process_weights_after_loading(model, model_config, target_device)
    return model.eval()
```
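The two nested context managers are what make the parameters materialize directly on the target device with the target dtype: set_default_torch_dtype changes the default floating-point dtype, and entering a torch.device as a context manager makes it the default device for newly created tensors. In isolation (using torch's own dtype setter in place of vLLM's helper):

```
# Sketch of the "build the module on the target device with the target dtype" trick;
# set_default_torch_dtype is vLLM's helper, replaced here by torch's own setter.
import torch

prev_dtype = torch.get_default_dtype()
torch.set_default_dtype(torch.bfloat16)          # roughly what set_default_torch_dtype does
try:
    with torch.device("cuda" if torch.cuda.is_available() else "cpu"):
        layer = torch.nn.Linear(8, 8)            # parameters are created on that device
finally:
    torch.set_default_dtype(prev_dtype)

print(layer.weight.dtype, layer.weight.device)   # torch.bfloat16, cuda:0 (or cpu)
```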
Next, how the weight files themselves are loaded; the key part is:
```
def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
    ...
    # Collect all weights the model expects to load; after loading, this set is
    # compared against what was actually loaded to catch anything missing
    weights_to_load = {name for name, _ in model.named_parameters()}
    ...
    if model_config.quantization is None:
        # model is not quantized
        loaded_weights = model.load_weights(
            self.get_all_weights(model_config, model)
        )
    ...
```
Questions

- When multiple workers load the model, how is the loading sharded across them?

Continuing from the code above, model.load_weights dispatches to the load_weights function of the specific model implementation:

```
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
        ("fused_qkv_a_proj", "q_a_proj", 0),
        ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
    ]
    ...
    # Grab the named_parameters this model maintains; used for the checks below
    params_dict = dict(self.named_parameters())
    loaded_params: set[str] = set()
    for name, loaded_weight in weights:
        if "rotary_emb.inv_freq" in name:
            continue

        spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
        if spec_layer is not None:
            continue  # skip spec decode layers for main model

        is_fuse_shared_experts_layer = (
            is_rocm_aiter_fusion_shared_expert_enabled()
            and ("mlp.shared_experts" in name)
        )

        for param_name, weight_name, shard_id in stacked_params_mapping:
            ...
            name_mapped = name.replace(weight_name, param_name)

            # QKV fusion is optional, fall back to normal
            # weight loading if it's not enabled
            # if go with fusion option, then update name
            if (
                param_name == "fused_qkv_a_proj"
            ) and name_mapped not in params_dict:
                # If the weight's name is not in this model's params_dict, skip it
                continue
            else:
                name = name_mapped

            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue

            if is_pp_missing_parameter(name, self):
                continue

            param = params_dict[name]
            weight_loader = param.weight_loader
            weight_loader(param, loaded_weight, shard_id)
            ...
```
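The stacked_params_mapping handles fused parameters: the checkpoint on disk has separate gate_proj and up_proj tensors, while the model keeps a single fused gate_up_proj parameter, so each source weight is copied into its shard of the fused tensor via the parameter's weight_loader. A toy version of that idea, with invented shapes and a simplified loader (real vLLM weight loaders also handle tensor-parallel sharding, quantization, etc.):

```
# Toy illustration of loading two checkpoint tensors into one fused parameter.
import torch

hidden, inter = 8, 16
fused = torch.empty(2 * inter, hidden)           # stands in for the model's gate_up_proj

def weight_loader(param: torch.Tensor, loaded: torch.Tensor, shard_id: int) -> None:
    # Copy the loaded tensor into its shard of the fused parameter.
    rows = loaded.shape[0]
    param[shard_id * rows:(shard_id + 1) * rows].copy_(loaded)

checkpoint = {
    "mlp.gate_proj.weight": torch.randn(inter, hidden),
    "mlp.up_proj.weight": torch.randn(inter, hidden),
}
stacked_params_mapping = [("gate_up_proj", "gate_proj", 0), ("gate_up_proj", "up_proj", 1)]

for name, tensor in checkpoint.items():
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name in name:
            weight_loader(fused, tensor, shard_id)

assert torch.equal(fused[:inter], checkpoint["mlp.gate_proj.weight"])
assert torch.equal(fused[inter:], checkpoint["mlp.up_proj.weight"])
print("fused gate_up_proj shape:", tuple(fused.shape))
```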
Summary
So far we have only looked at the model loading logic. There are many other complex flows as well, such as KV cache management, scheduling, input/output processing, the tokenizer, and so on, which are still left to study.