audio_pipeline
This commit is contained in:
172
workers/transcribe/internal/consumer/consumer.go
Normal file
172
workers/transcribe/internal/consumer/consumer.go
Normal file
@@ -0,0 +1,172 @@
|
||||
package consumer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
amqp "github.com/rabbitmq/amqp091-go"
|
||||
|
||||
"github.com/yourorg/transcribe/internal/config"
|
||||
"github.com/yourorg/transcribe/internal/models"
|
||||
"github.com/yourorg/transcribe/internal/nexara"
|
||||
"github.com/yourorg/transcribe/internal/prompts"
|
||||
)
|
||||
|
||||
type Consumer struct {
|
||||
cfg config.Config
|
||||
ch *amqp.Channel
|
||||
nexara *nexara.Client
|
||||
prompts *prompts.Loader
|
||||
}
|
||||
|
||||
func New(cfg config.Config, ch *amqp.Channel) (*Consumer, error) {
|
||||
if err := setupTopology(ch, cfg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Consumer{
|
||||
cfg: cfg,
|
||||
ch: ch,
|
||||
nexara: nexara.New(cfg.NexaraBaseURL, cfg.NexaraAPIKey, cfg.NexaraModel, cfg.NexaraTimeout),
|
||||
prompts: prompts.New(cfg.PromptsSource, cfg.PromptsFile, cfg.PromptsBaseURL, cfg.PromptsAPIKey, cfg.PromptsSection),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func setupTopology(ch *amqp.Channel, cfg config.Config) error {
|
||||
if err := ch.ExchangeDeclare("dlx", "direct", true, false, false, false, nil); err != nil {
|
||||
return fmt.Errorf("declare dlx: %w", err)
|
||||
}
|
||||
if err := ch.ExchangeDeclare(cfg.InputExchange, "direct", true, false, false, false, nil); err != nil {
|
||||
return fmt.Errorf("declare input exchange: %w", err)
|
||||
}
|
||||
if err := ch.ExchangeDeclare(cfg.OutputExchange, "fanout", true, false, false, false, nil); err != nil {
|
||||
return fmt.Errorf("declare output exchange: %w", err)
|
||||
}
|
||||
|
||||
dlqArgs := amqp.Table{
|
||||
"x-dead-letter-exchange": "dlx",
|
||||
"x-dead-letter-routing-key": cfg.InputQueue + ".failed",
|
||||
}
|
||||
if _, err := ch.QueueDeclare(cfg.InputQueue, true, false, false, false, dlqArgs); err != nil {
|
||||
return fmt.Errorf("declare input queue: %w", err)
|
||||
}
|
||||
if _, err := ch.QueueDeclare(cfg.InputQueue+".failed", true, false, false, false, nil); err != nil {
|
||||
return fmt.Errorf("declare dlq: %w", err)
|
||||
}
|
||||
if err := ch.QueueBind(cfg.InputQueue+".failed", cfg.InputQueue+".failed", "dlx", false, nil); err != nil {
|
||||
return fmt.Errorf("bind dlq: %w", err)
|
||||
}
|
||||
if err := ch.QueueBind(cfg.InputQueue, cfg.InputRoutingKey, cfg.InputExchange, false, nil); err != nil {
|
||||
return fmt.Errorf("bind input queue: %w", err)
|
||||
}
|
||||
|
||||
for _, q := range []string{cfg.AnalyseQueue, cfg.TaggingQueue} {
|
||||
if _, err := ch.QueueDeclare(q, true, false, false, false, nil); err != nil {
|
||||
return fmt.Errorf("declare queue %s: %w", q, err)
|
||||
}
|
||||
if err := ch.QueueBind(q, "", cfg.OutputExchange, false, nil); err != nil {
|
||||
return fmt.Errorf("bind queue %s: %w", q, err)
|
||||
}
|
||||
}
|
||||
|
||||
return ch.Qos(cfg.Prefetch, 0, false)
|
||||
}
|
||||
|
||||
func (c *Consumer) Run(ctx context.Context) error {
|
||||
if err := c.ch.Confirm(false); err != nil {
|
||||
return fmt.Errorf("confirm mode: %w", err)
|
||||
}
|
||||
|
||||
msgs, err := c.ch.Consume(c.cfg.InputQueue, "", false, false, false, false, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
slog.Info("transcribe worker started", "queue", c.cfg.InputQueue, "output_exchange", c.cfg.OutputExchange)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
case d, ok := <-msgs:
|
||||
if !ok {
|
||||
return fmt.Errorf("delivery channel closed")
|
||||
}
|
||||
c.handle(ctx, d)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Consumer) handle(ctx context.Context, d amqp.Delivery) {
|
||||
var task models.AudioTask
|
||||
if err := json.Unmarshal(d.Body, &task); err != nil {
|
||||
slog.Warn("bad message", "delivery_tag", d.DeliveryTag, "error", err)
|
||||
_ = d.Nack(false, false)
|
||||
return
|
||||
}
|
||||
|
||||
slog.Info("message received", "task_id", task.TaskID, "file_path", task.FilePath, "filename", task.Filename)
|
||||
|
||||
txCtx, cancel := context.WithTimeout(ctx, c.cfg.NexaraTimeout+30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
text, lang, segments, err := c.nexara.TranscribeFile(txCtx, task.FilePath)
|
||||
if err != nil {
|
||||
slog.Warn("transcription failed", "task_id", task.TaskID, "error", err)
|
||||
_ = d.Nack(false, false)
|
||||
return
|
||||
}
|
||||
|
||||
promptList, err := c.prompts.Load(txCtx)
|
||||
if err != nil {
|
||||
slog.Warn("prompts load failed", "task_id", task.TaskID, "error", err)
|
||||
_ = d.Nack(false, false)
|
||||
return
|
||||
}
|
||||
|
||||
result := models.TranscriptionResult{
|
||||
TaskID: task.TaskID,
|
||||
Filename: task.Filename,
|
||||
FilePath: task.FilePath,
|
||||
Transcription: text,
|
||||
Language: lang,
|
||||
Segments: segments,
|
||||
Prompts: promptList,
|
||||
TranscribedAt: time.Now().Unix(),
|
||||
}
|
||||
|
||||
body, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
slog.Warn("marshal failed", "task_id", task.TaskID, "error", err)
|
||||
_ = d.Nack(false, false)
|
||||
return
|
||||
}
|
||||
|
||||
confirms := c.ch.NotifyPublish(make(chan amqp.Confirmation, 1))
|
||||
if err := c.ch.PublishWithContext(txCtx, c.cfg.OutputExchange, "", false, false, amqp.Publishing{
|
||||
ContentType: "application/json",
|
||||
Body: body,
|
||||
DeliveryMode: amqp.Persistent,
|
||||
}); err != nil {
|
||||
slog.Warn("publish failed, requeue", "task_id", task.TaskID, "error", err)
|
||||
_ = d.Nack(false, true)
|
||||
return
|
||||
}
|
||||
select {
|
||||
case confirm := <-confirms:
|
||||
if !confirm.Ack {
|
||||
slog.Warn("publish not confirmed, requeue", "task_id", task.TaskID)
|
||||
_ = d.Nack(false, true)
|
||||
return
|
||||
}
|
||||
case <-txCtx.Done():
|
||||
slog.Warn("publish timeout, requeue", "task_id", task.TaskID)
|
||||
_ = d.Nack(false, true)
|
||||
return
|
||||
}
|
||||
|
||||
slog.Info("transcribed", "task_id", task.TaskID, "language", lang, "chars", len(text), "segments", len(segments), "prompts", len(promptList))
|
||||
_ = d.Ack(false)
|
||||
}
|
||||
Reference in New Issue
Block a user