diff --git a/src/agents/Dockerfile b/src/agents/Dockerfile
new file mode 100644
index 00000000..dcd2ebf2
--- /dev/null
+++ b/src/agents/Dockerfile
@@ -0,0 +1,24 @@
+FROM python:3.13-slim AS base
+
+FROM base AS builder
+
+WORKDIR /builder
+
+COPY pyproject.toml .
+
+RUN mkdir /install && \
+    pip install --prefix=/install .
+
+FROM base AS production
+
+WORKDIR /app
+
+# Un-privileged user running the application
+ARG DOCKER_USER
+USER ${DOCKER_USER}
+
+COPY --from=builder /install /usr/local
+
+COPY . .
+
+CMD ["python", "multi-user-transcriber.py", "start"]
diff --git a/src/agents/multi-user-transcriber.py b/src/agents/multi-user-transcriber.py
new file mode 100644
index 00000000..16fb78ff
--- /dev/null
+++ b/src/agents/multi-user-transcriber.py
@@ -0,0 +1,211 @@
+"""Multi user transcription agent."""
+
+import asyncio
+import logging
+import os
+
+from dotenv import load_dotenv
+from livekit import api, rtc
+from livekit.agents import (
+    Agent,
+    AgentSession,
+    AutoSubscribe,
+    JobContext,
+    JobProcess,
+    JobRequest,
+    RoomInputOptions,
+    RoomIO,
+    RoomOutputOptions,
+    WorkerOptions,
+    WorkerPermissions,
+    cli,
+    utils,
+)
+from livekit.plugins import deepgram, silero
+
+load_dotenv()
+
+logger = logging.getLogger("transcriber")
+
+TRANSCRIBER_AGENT_NAME = os.getenv("TRANSCRIBER_AGENT_NAME", "multi-user-transcriber")
+
+
+class Transcriber(Agent):
+    """Create a transcription agent for a specific participant."""
+
+    def __init__(self, *, participant_identity: str):
+        """Init transcription agent.
+
+        Args:
+            participant_identity: Identity of the participant being transcribed.
+        """
+        super().__init__(
+            instructions="not-needed",
+            stt=deepgram.STT(),
+        )
+        self.participant_identity = participant_identity
+
+
+class MultiUserTranscriber:
+    """Manage transcription sessions for multiple room participants."""
+
+    def __init__(self, ctx: JobContext):
+        """Init multi user transcription agent.
+
+        Args:
+            ctx: Job context giving access to the room and process userdata.
+        """
+        self.ctx = ctx
+        # Started sessions, keyed by participant identity.
+        self._sessions: dict[str, AgentSession] = {}
+        # In-flight start/close tasks, kept so aclose() can cancel them.
+        self._tasks: set[asyncio.Task] = set()
+
+    def start(self):
+        """Start listening for participant connection events."""
+        self.ctx.room.on("participant_connected", self.on_participant_connected)
+        self.ctx.room.on("participant_disconnected", self.on_participant_disconnected)
+
+    async def aclose(self):
+        """Close all sessions and cleanup resources."""
+        # Unsubscribe first so no new session can be started during shutdown.
+        self.ctx.room.off("participant_connected", self.on_participant_connected)
+        self.ctx.room.off("participant_disconnected", self.on_participant_disconnected)
+
+        await utils.aio.cancel_and_wait(*self._tasks)
+
+        # Detach the sessions before awaiting so each one is closed only once.
+        sessions = list(self._sessions.values())
+        self._sessions.clear()
+        await asyncio.gather(*[self._close_session(session) for session in sessions])
+
+    def on_participant_connected(self, participant: rtc.RemoteParticipant):
+        """Handle new participant connection by starting transcription session."""
+        if participant.identity in self._sessions:
+            return
+
+        logger.info(f"starting session for {participant.identity}")
+        task = asyncio.create_task(self._start_session(participant))
+        self._tasks.add(task)
+
+        def on_task_done(task: asyncio.Task):
+            self._tasks.discard(task)
+            # A cancelled or failed start has no result to register; calling
+            # task.result() here would raise inside the event-loop callback.
+            if task.cancelled():
+                return
+            if (exc := task.exception()) is not None:
+                logger.error(
+                    f"failed to start session for {participant.identity}",
+                    exc_info=exc,
+                )
+                return
+            self._sessions[participant.identity] = task.result()
+
+        task.add_done_callback(on_task_done)
+
+    def on_participant_disconnected(self, participant: rtc.RemoteParticipant):
+        """Handle participant disconnection by closing transcription session."""
+        # The participant may leave before its session was registered, so a
+        # missing entry is expected and must not raise KeyError.
+        if (session := self._sessions.pop(participant.identity, None)) is None:
+            return
+
+        logger.info(f"closing session for {participant.identity}")
+        task = asyncio.create_task(self._close_session(session))
+        self._tasks.add(task)
+        task.add_done_callback(lambda _: self._tasks.discard(task))
+
+    async def _start_session(self, participant: rtc.RemoteParticipant) -> AgentSession:
+        """Create and start transcription session for participant."""
+        if participant.identity in self._sessions:
+            return self._sessions[participant.identity]
+
+        session = AgentSession(
+            vad=self.ctx.proc.userdata["vad"],
+        )
+        room_io = RoomIO(
+            agent_session=session,
+            room=self.ctx.room,
+            participant=participant,
+            input_options=RoomInputOptions(
+                text_enabled=False,
+            ),
+            output_options=RoomOutputOptions(
+                transcription_enabled=True,
+                audio_enabled=False,
+            ),
+        )
+        await room_io.start()
+        await session.start(
+            agent=Transcriber(
+                participant_identity=participant.identity,
+            )
+        )
+        return session
+
+    async def _close_session(self, sess: AgentSession) -> None:
+        """Close and cleanup transcription session."""
+        await sess.drain()
+        await sess.aclose()
+
+
+async def entrypoint(ctx: JobContext):
+    """Initialize and run the multi-user transcriber."""
+    transcriber = MultiUserTranscriber(ctx)
+    transcriber.start()
+
+    await ctx.connect(auto_subscribe=AutoSubscribe.AUDIO_ONLY)
+    for participant in ctx.room.remote_participants.values():
+        transcriber.on_participant_connected(participant)
+
+    async def cleanup():
+        await transcriber.aclose()
+
+    ctx.add_shutdown_callback(cleanup)
+
+
+async def handle_transcriber_job_request(job_req: JobRequest) -> None:
+    """Accept job if no transcriber exists in room, otherwise reject."""
+    room_name = job_req.room.name
+    transcriber_id = f"{TRANSCRIBER_AGENT_NAME}-{room_name}"
+
+    async with api.LiveKitAPI() as lkapi:
+        try:
+            response = await lkapi.room.list_participants(
+                list=api.ListParticipantsRequest(room=room_name)
+            )
+
+            transcriber_exists = any(
+                p.kind == rtc.ParticipantKind.PARTICIPANT_KIND_AGENT
+                and p.identity == transcriber_id
+                for p in response.participants
+            )
+
+            if transcriber_exists:
+                logger.info(f"Transcriber exists in {room_name} - rejecting")
+                await job_req.reject()
+            else:
+                logger.info(f"Accepting job for {room_name}")
+                await job_req.accept(identity=transcriber_id)
+
+        except Exception:
+            logger.exception(f"Error processing job for {room_name}")
+            await job_req.reject()
+
+
+def prewarm(proc: JobProcess):
+    """Preload voice activity detection model."""
+    proc.userdata["vad"] = silero.VAD.load()
+
+
+if __name__ == "__main__":
+    cli.run_app(
+        WorkerOptions(
+            entrypoint_fnc=entrypoint,
+            request_fnc=handle_transcriber_job_request,
+            prewarm_fnc=prewarm,
+            agent_name=TRANSCRIBER_AGENT_NAME,
+            permissions=WorkerPermissions(hidden=True),
+        )
+    )
diff --git a/src/agents/pyproject.toml b/src/agents/pyproject.toml
new file mode 100644
index 00000000..53022995
--- /dev/null
+++ b/src/agents/pyproject.toml
@@ -0,0 +1,51 @@
+
+[project]
+name = "agents"
+version = "0.1.23"
+requires-python = ">=3.12"
+dependencies = [
+    "livekit-agents==1.2.6",
+    "livekit-plugins-deepgram==1.2.6",
+    "livekit-plugins-silero==1.2.6",
+    "python-dotenv==1.1.1"
+]
+
+[project.optional-dependencies]
+dev = [
+    "ruff==0.12.0",
+]
+
+[build-system]
+requires = ["setuptools>=61.0"]
+build-backend = "setuptools.build_meta"
+
+[tool.ruff]
+target-version = "py313"
+
+[tool.ruff.lint]
+select = [
+    "B",  # flake8-bugbear
+    "C4",  # flake8-comprehensions
+    "D",  # pydocstyle
+    "E",  # pycodestyle error
+    "F",  # Pyflakes
+    "I",  # Isort
+    "ISC",  # flake8-implicit-str-concat
+    "PLC",  # Pylint Convention
+    "PLE",  # Pylint Error
+    "PLR",  # Pylint Refactor
+    "PLW",  # Pylint Warning
+    "RUF100",  # Ruff unused-noqa
+    "S",  # flake8-bandit
+    "T20",  # flake8-print
+    "W",  # pycodestyle warning
+]
+
+[tool.ruff.lint.per-file-ignores]
+"tests/*" = [
+    "S101",  # use of assert
+]
+
+[tool.ruff.lint.pydocstyle]
+# Use Google-style docstrings.
+convention = "google"