200 lines
6.1 KiB
Go
200 lines
6.1 KiB
Go
package consumer
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"log/slog"
|
|
"time"
|
|
|
|
amqp "github.com/rabbitmq/amqp091-go"
|
|
|
|
"github.com/postmet/transcribe/internal/config"
|
|
"github.com/postmet/transcribe/internal/models"
|
|
"github.com/postmet/transcribe/internal/nexara"
|
|
"github.com/postmet/transcribe/internal/pipelinestatus"
|
|
"github.com/postmet/transcribe/internal/prompts"
|
|
)
|
|
|
|
type Consumer struct {
|
|
cfg config.Config
|
|
ch *amqp.Channel
|
|
nexara *nexara.Client
|
|
prompts *prompts.Loader
|
|
}
|
|
|
|
func New(cfg config.Config, ch *amqp.Channel) (*Consumer, error) {
|
|
if err := setupTopology(ch, cfg); err != nil {
|
|
return nil, err
|
|
}
|
|
return &Consumer{
|
|
cfg: cfg,
|
|
ch: ch,
|
|
nexara: nexara.New(cfg.NexaraBaseURL, cfg.NexaraAPIKey, cfg.NexaraModel, cfg.NexaraTimeout),
|
|
prompts: prompts.New(cfg.PromptsSource, cfg.PromptsFile, cfg.PromptsBaseURL, cfg.PromptsAPIKey, cfg.PromptsSection),
|
|
}, nil
|
|
}
|
|
|
|
func setupTopology(ch *amqp.Channel, cfg config.Config) error {
|
|
if err := ch.ExchangeDeclare("dlx", "direct", true, false, false, false, nil); err != nil {
|
|
return fmt.Errorf("declare dlx: %w", err)
|
|
}
|
|
if err := ch.ExchangeDeclare(cfg.InputExchange, "direct", true, false, false, false, nil); err != nil {
|
|
return fmt.Errorf("declare input exchange: %w", err)
|
|
}
|
|
if err := ch.ExchangeDeclare(cfg.OutputExchange, "fanout", true, false, false, false, nil); err != nil {
|
|
return fmt.Errorf("declare output exchange: %w", err)
|
|
}
|
|
|
|
dlqArgs := amqp.Table{
|
|
"x-dead-letter-exchange": "dlx",
|
|
"x-dead-letter-routing-key": cfg.InputQueue + ".failed",
|
|
}
|
|
if _, err := ch.QueueDeclare(cfg.InputQueue, true, false, false, false, dlqArgs); err != nil {
|
|
return fmt.Errorf("declare input queue: %w", err)
|
|
}
|
|
if _, err := ch.QueueDeclare(cfg.InputQueue+".failed", true, false, false, false, nil); err != nil {
|
|
return fmt.Errorf("declare dlq: %w", err)
|
|
}
|
|
if err := ch.QueueBind(cfg.InputQueue+".failed", cfg.InputQueue+".failed", "dlx", false, nil); err != nil {
|
|
return fmt.Errorf("bind dlq: %w", err)
|
|
}
|
|
if err := ch.QueueBind(cfg.InputQueue, cfg.InputRoutingKey, cfg.InputExchange, false, nil); err != nil {
|
|
return fmt.Errorf("bind input queue: %w", err)
|
|
}
|
|
|
|
for _, q := range []string{cfg.AnalyseQueue, cfg.TaggingQueue} {
|
|
if _, err := ch.QueueDeclare(q, true, false, false, false, nil); err != nil {
|
|
return fmt.Errorf("declare queue %s: %w", q, err)
|
|
}
|
|
if err := ch.QueueBind(q, "", cfg.OutputExchange, false, nil); err != nil {
|
|
return fmt.Errorf("bind queue %s: %w", q, err)
|
|
}
|
|
}
|
|
if err := pipelinestatus.DeclareQueue(ch, cfg.StatusQueue); err != nil {
|
|
return fmt.Errorf("declare status queue: %w", err)
|
|
}
|
|
|
|
return ch.Qos(cfg.Prefetch, 0, false)
|
|
}
|
|
|
|
func (c *Consumer) Run(ctx context.Context) error {
|
|
if err := c.ch.Confirm(false); err != nil {
|
|
return fmt.Errorf("confirm mode: %w", err)
|
|
}
|
|
|
|
msgs, err := c.ch.Consume(c.cfg.InputQueue, "", false, false, false, false, nil)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
slog.Info("transcribe worker started",
|
|
"queue", c.cfg.InputQueue,
|
|
"output_exchange", c.cfg.OutputExchange,
|
|
"status_queue", c.cfg.StatusQueue,
|
|
)
|
|
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return nil
|
|
case d, ok := <-msgs:
|
|
if !ok {
|
|
return fmt.Errorf("delivery channel closed")
|
|
}
|
|
c.handle(ctx, d)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (c *Consumer) handle(ctx context.Context, d amqp.Delivery) {
|
|
var task models.AudioTask
|
|
if err := json.Unmarshal(d.Body, &task); err != nil {
|
|
slog.Warn("bad message", "delivery_tag", d.DeliveryTag, "error", err)
|
|
_ = d.Nack(false, false)
|
|
return
|
|
}
|
|
|
|
slog.Info("message received", "task_id", task.TaskID, "file_path", task.FilePath, "filename", task.Filename)
|
|
|
|
txCtx, cancel := context.WithTimeout(ctx, c.cfg.NexaraTimeout+30*time.Second)
|
|
defer cancel()
|
|
|
|
c.publishStatus(txCtx, task.TaskID, task.Filename, pipelinestatus.StatusInProgress, pipelinestatus.StageTranscribing, "")
|
|
|
|
text, lang, segments, err := c.nexara.TranscribeFile(txCtx, task.FilePath)
|
|
if err != nil {
|
|
slog.Warn("transcription failed", "task_id", task.TaskID, "error", err)
|
|
c.publishStatus(txCtx, task.TaskID, task.Filename, pipelinestatus.StatusError, pipelinestatus.StageTranscribing, err.Error())
|
|
_ = d.Nack(false, false)
|
|
return
|
|
}
|
|
|
|
promptList, err := c.prompts.Load(txCtx)
|
|
if err != nil {
|
|
slog.Warn("prompts load failed", "task_id", task.TaskID, "error", err)
|
|
c.publishStatus(txCtx, task.TaskID, task.Filename, pipelinestatus.StatusError, pipelinestatus.StageTranscribing, err.Error())
|
|
_ = d.Nack(false, false)
|
|
return
|
|
}
|
|
|
|
result := models.TranscriptionResult{
|
|
TaskID: task.TaskID,
|
|
Filename: task.Filename,
|
|
FilePath: task.FilePath,
|
|
Transcription: text,
|
|
Language: lang,
|
|
Segments: segments,
|
|
Prompts: promptList,
|
|
TranscribedAt: time.Now().Unix(),
|
|
}
|
|
|
|
body, err := json.Marshal(result)
|
|
if err != nil {
|
|
slog.Warn("marshal failed", "task_id", task.TaskID, "error", err)
|
|
_ = d.Nack(false, false)
|
|
return
|
|
}
|
|
|
|
confirms := c.ch.NotifyPublish(make(chan amqp.Confirmation, 1))
|
|
if err := c.ch.PublishWithContext(txCtx, c.cfg.OutputExchange, "", false, false, amqp.Publishing{
|
|
ContentType: "application/json",
|
|
Body: body,
|
|
DeliveryMode: amqp.Persistent,
|
|
}); err != nil {
|
|
slog.Warn("publish failed, requeue", "task_id", task.TaskID, "error", err)
|
|
_ = d.Nack(false, true)
|
|
return
|
|
}
|
|
select {
|
|
case confirm := <-confirms:
|
|
if !confirm.Ack {
|
|
slog.Warn("publish not confirmed, requeue", "task_id", task.TaskID)
|
|
_ = d.Nack(false, true)
|
|
return
|
|
}
|
|
case <-txCtx.Done():
|
|
slog.Warn("publish timeout, requeue", "task_id", task.TaskID)
|
|
_ = d.Nack(false, true)
|
|
return
|
|
}
|
|
|
|
slog.Info("transcribed", "task_id", task.TaskID, "language", lang, "chars", len(text), "segments", len(segments), "prompts", len(promptList))
|
|
_ = d.Ack(false)
|
|
}
|
|
|
|
func (c *Consumer) publishStatus(ctx context.Context, taskID, filename, status, stage, errMsg string) {
|
|
ev := pipelinestatus.Event{
|
|
TaskID: taskID,
|
|
Filename: filename,
|
|
Status: status,
|
|
Stage: stage,
|
|
}
|
|
if errMsg != "" {
|
|
ev.Error = errMsg
|
|
}
|
|
if err := pipelinestatus.Publish(ctx, c.ch, c.cfg.StatusQueue, ev); err != nil {
|
|
slog.Warn("status publish failed", "task_id", taskID, "stage", stage, "error", err)
|
|
}
|
|
}
|