statuses update
This commit is contained in:
@@ -18,6 +18,8 @@ import (
|
||||
"github.com/joho/godotenv"
|
||||
_ "github.com/jackc/pgx/v5/stdlib"
|
||||
amqp "github.com/rabbitmq/amqp091-go"
|
||||
|
||||
"github.com/postmet/analyse/internal/pipelinestatus"
|
||||
)
|
||||
|
||||
func init() {
|
||||
@@ -326,7 +328,7 @@ func saveAnalysis(ctx context.Context, db *sql.DB, task WorkerMessage, analysis
|
||||
transcription = COALESCE(NULLIF($4, ''), transcription),
|
||||
metadata = COALESCE($5::jsonb, metadata),
|
||||
updated_at = now(),
|
||||
status = CASE WHEN tagging IS NOT NULL THEN 'done' ELSE status END
|
||||
status = CASE WHEN tagging IS NOT NULL THEN 'done' ELSE 'in_progress' END
|
||||
WHERE task_id = $1
|
||||
RETURNING (analysis IS NOT NULL AND tagging IS NOT NULL)
|
||||
`, task.TaskID, string(analysis), task.Filename, task.Transcription, string(metadata)).Scan(&complete)
|
||||
@@ -336,6 +338,41 @@ func saveAnalysis(ctx context.Context, db *sql.DB, task WorkerMessage, analysis
|
||||
return complete, nil
|
||||
}
|
||||
|
||||
// setTaskInProgress upserts the results row for taskID with status
// 'in_progress'. A new row is inserted with the given filename (an empty
// filename is stored as NULL). If a row for task_id already exists, its status
// is overwritten, updated_at is refreshed, and filename is replaced only when
// a non-empty one is supplied. The ExecContext error is returned unchanged.
//
// NOTE(review): the conflict branch overwrites whatever status the row held
// (including 'done' or 'error') — confirm this is intended for redeliveries.
func setTaskInProgress(ctx context.Context, db *sql.DB, taskID, filename string) error {
|
||||
_, err := db.ExecContext(ctx, `
|
||||
INSERT INTO results (task_id, filename, status)
|
||||
VALUES ($1, NULLIF($2, ''), 'in_progress')
|
||||
ON CONFLICT (task_id) DO UPDATE SET
|
||||
status = 'in_progress',
|
||||
filename = COALESCE(NULLIF(EXCLUDED.filename, ''), results.filename),
|
||||
updated_at = now()
|
||||
`, taskID, filename)
|
||||
return err
|
||||
}
|
||||
|
||||
// setTaskError upserts the results row for taskID with status 'error'. It
// inserts a row carrying only task_id and status or, when the row already
// exists, sets status to 'error' and refreshes updated_at. The ExecContext
// error is returned unchanged.
func setTaskError(ctx context.Context, db *sql.DB, taskID string) error {
|
||||
_, err := db.ExecContext(ctx, `
|
||||
INSERT INTO results (task_id, status) VALUES ($1, 'error')
|
||||
ON CONFLICT (task_id) DO UPDATE SET status = 'error', updated_at = now()
|
||||
`, taskID)
|
||||
return err
|
||||
}
|
||||
|
||||
func publishStatus(ctx context.Context, ch *amqp.Channel, queue, taskID, filename, status, stage, errMsg string) {
|
||||
ev := pipelinestatus.Event{
|
||||
TaskID: taskID,
|
||||
Filename: filename,
|
||||
Status: status,
|
||||
Stage: stage,
|
||||
}
|
||||
if errMsg != "" {
|
||||
ev.Error = errMsg
|
||||
}
|
||||
if err := pipelinestatus.Publish(ctx, ch, queue, ev); err != nil {
|
||||
slog.Warn("status publish failed", "worker", "analyse", "task_id", taskID, "stage", stage, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ===================== MAIN =====================
|
||||
|
||||
func loadDotenv() {
|
||||
@@ -360,6 +397,7 @@ func main() {
|
||||
apiURL := getEnv("YANDEX_API_URL", "https://ai.api.cloud.yandex.net/v1/chat/completions")
|
||||
inputQueue := getEnv("ANALYSE_QUEUE", "analyse")
|
||||
finalQueue := getEnv("FINAL_QUEUE", "final")
|
||||
statusQueue := getEnv("STATUS_QUEUE", "pipeline.status")
|
||||
|
||||
if token == "" {
|
||||
slog.Error("YANDEX_API_KEY is required")
|
||||
@@ -397,6 +435,10 @@ func main() {
|
||||
slog.Error("declare queue failed", "queue", finalQueue, "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
if err := pipelinestatus.DeclareQueue(ch, statusQueue); err != nil {
|
||||
slog.Error("declare status queue failed", "queue", statusQueue, "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
ch.Qos(1, 0, false)
|
||||
|
||||
msgs, err := ch.Consume(inputQueue, "", false, false, false, false, nil)
|
||||
@@ -404,7 +446,7 @@ func main() {
|
||||
slog.Error("consume failed", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
slog.Info("worker started", "worker", "analyse", "queue", inputQueue, "model", model)
|
||||
slog.Info("worker started", "worker", "analyse", "queue", inputQueue, "status_queue", statusQueue, "model", model)
|
||||
|
||||
for d := range msgs {
|
||||
taskStart := time.Now()
|
||||
@@ -449,9 +491,18 @@ func main() {
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
||||
|
||||
if err := setTaskInProgress(ctx, db, task.TaskID, task.Filename); err != nil {
|
||||
slog.Warn("set in_progress failed", "worker", "analyse", "task_id", task.TaskID, "error", err)
|
||||
}
|
||||
publishStatus(ctx, ch, statusQueue, task.TaskID, task.Filename,
|
||||
pipelinestatus.StatusInProgress, pipelinestatus.StageAnalysing, "")
|
||||
|
||||
result, stats, err := runAnalysis(ctx, apiURL, model, task.TaskID, task.Transcription, task.Prompts)
|
||||
if err != nil {
|
||||
cancel()
|
||||
_ = setTaskError(context.Background(), db, task.TaskID)
|
||||
publishStatus(context.Background(), ch, statusQueue, task.TaskID, task.Filename,
|
||||
pipelinestatus.StatusError, pipelinestatus.StageAnalysing, err.Error())
|
||||
slog.Warn("task failed, discarded",
|
||||
"worker", "analyse", "task_id", task.TaskID,
|
||||
"llm_calls_done", stats.LLMCalls,
|
||||
@@ -482,7 +533,7 @@ func main() {
|
||||
}
|
||||
|
||||
if complete {
|
||||
notifyFinal(ctx, ch, db, finalQueue, task.TaskID, "analyse")
|
||||
notifyFinal(ctx, ch, db, finalQueue, statusQueue, task.TaskID, task.Filename, "analyse")
|
||||
slog.Info("task complete", append(taskAttrs, "was_last", "analyse")...)
|
||||
} else {
|
||||
slog.Info("task partial", append(taskAttrs, "waiting_for", "tagging")...)
|
||||
@@ -549,12 +600,14 @@ func loadFinalPayload(ctx context.Context, db *sql.DB, taskID string) ([]byte, e
|
||||
return json.Marshal(msg)
|
||||
}
|
||||
|
||||
func notifyFinal(ctx context.Context, ch *amqp.Channel, db *sql.DB, queue, taskID, worker string) {
|
||||
func notifyFinal(ctx context.Context, ch *amqp.Channel, db *sql.DB, queue, statusQueue, taskID, filename, worker string) {
|
||||
body, err := loadFinalPayload(ctx, db, taskID)
|
||||
if err != nil {
|
||||
slog.Warn("load final payload failed", "worker", worker, "task_id", taskID, "error", err)
|
||||
return
|
||||
}
|
||||
publishStatus(ctx, ch, statusQueue, taskID, filename,
|
||||
pipelinestatus.StatusDone, pipelinestatus.StageCompleted, "")
|
||||
if err := ch.PublishWithContext(ctx, "", queue, false, false,
|
||||
amqp.Publishing{
|
||||
ContentType: "application/json",
|
||||
|
||||
164
workers/analyse/cmd/analyse/main_test.go
Normal file
164
workers/analyse/cmd/analyse/main_test.go
Normal file
@@ -0,0 +1,164 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestTruncate(t *testing.T) {
|
||||
if got := truncate("short", 10); got != "short" {
|
||||
t.Fatalf("got %q", got)
|
||||
}
|
||||
long := strings.Repeat("a", 20)
|
||||
got := truncate(long, 10)
|
||||
if len(got) != 13 || !strings.HasSuffix(got, "...") {
|
||||
t.Fatalf("got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenFingerprint(t *testing.T) {
|
||||
if tokenFingerprint("short") != "***" {
|
||||
t.Fatal("short token should be masked")
|
||||
}
|
||||
fp := tokenFingerprint("abcdefghijklmnop")
|
||||
if !strings.HasPrefix(fp, "abcdefgh") || !strings.HasSuffix(fp, "mnop") {
|
||||
t.Fatalf("got %q", fp)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPromptQuery(t *testing.T) {
|
||||
p := Prompt{Name: "behavioral", Prompt: "Оцени звонок"}
|
||||
q := buildPromptQuery("текст транскрипции", p)
|
||||
if !strings.Contains(q, "Оцени звонок") {
|
||||
t.Fatal("missing prompt text")
|
||||
}
|
||||
if !strings.Contains(q, "текст транскрипции") {
|
||||
t.Fatal("missing transcription")
|
||||
}
|
||||
if !strings.Contains(q, "=== ТРАНСКРИПЦИЯ ===") {
|
||||
t.Fatal("missing section header")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunAnalysis(t *testing.T) {
|
||||
t.Setenv("YANDEX_API_KEY", "test-token")
|
||||
|
||||
var calls int
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
calls++
|
||||
if auth := r.Header.Get("Authorization"); auth != "Bearer test-token" {
|
||||
t.Fatalf("auth: got %q", auth)
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"choices": []map[string]any{
|
||||
{"message": map[string]any{"content": `{"ok":true}`}},
|
||||
},
|
||||
"usage": map[string]any{
|
||||
"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15,
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
prompts := []Prompt{
|
||||
{Name: "behavioral", Prompt: "p1"},
|
||||
{Name: "client_data", Prompt: "p2"},
|
||||
}
|
||||
result, stats, err := runAnalysis(context.Background(), srv.URL, "model", "task-1", "транскрипт", prompts)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if calls != 2 {
|
||||
t.Fatalf("llm calls: got %d, want 2", calls)
|
||||
}
|
||||
if len(result) != 2 {
|
||||
t.Fatalf("result keys: %d", len(result))
|
||||
}
|
||||
if stats.LLMCalls != 2 || stats.TotalTokens != 30 {
|
||||
t.Fatalf("stats: %+v", stats)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunAnalysisSkipsEmptyNames(t *testing.T) {
|
||||
t.Setenv("YANDEX_API_KEY", "k")
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"choices": []map[string]any{{"message": map[string]any{"content": `{"x":1}`}}},
|
||||
})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
result, stats, err := runAnalysis(context.Background(), srv.URL, "m", "t", "text", []Prompt{
|
||||
{Name: "", Prompt: "skip"},
|
||||
{Name: "valid", Prompt: "go"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(result) != 1 {
|
||||
t.Fatalf("want 1 result, got %d", len(result))
|
||||
}
|
||||
if stats.LLMCalls != 1 {
|
||||
t.Fatalf("llm calls: %d", stats.LLMCalls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractFilePath(t *testing.T) {
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"task_id": "01HX",
|
||||
"file_path": "/data/storage/processing/01HX.wav",
|
||||
})
|
||||
if got := extractFilePath(body); got != "/data/storage/processing/01HX.wav" {
|
||||
t.Fatalf("got %q", got)
|
||||
}
|
||||
if extractFilePath([]byte("not json")) != "" {
|
||||
t.Fatal("invalid json should return empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteProcessingFile(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
processingDir := filepath.Join(dir, "processing")
|
||||
if err := os.MkdirAll(processingDir, 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
path := filepath.Join(processingDir, "01TASK.wav")
|
||||
if err := os.WriteFile(path, []byte("audio"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
deleteProcessingFile(path, "01TASK", "analyse")
|
||||
if _, err := os.Stat(path); !os.IsNotExist(err) {
|
||||
t.Fatal("file should be deleted")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteProcessingFileRejectsOutsideProcessing(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "incoming", "file.wav")
|
||||
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(path, []byte("x"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
deleteProcessingFile(path, "t", "analyse")
|
||||
if _, err := os.Stat(path); err != nil {
|
||||
t.Fatal("file outside processing must not be deleted")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccumulateUsage(t *testing.T) {
|
||||
stats := &analysisStats{}
|
||||
accumulateUsage(stats, &llmCallResult{Usage: &tokenUsage{TotalTokens: 10, PromptTokens: 6, CompletionTokens: 4}})
|
||||
accumulateUsage(stats, &llmCallResult{Usage: &tokenUsage{TotalTokens: 5, PromptTokens: 3, CompletionTokens: 2}})
|
||||
if stats.LLMCalls != 2 || stats.TotalTokens != 15 {
|
||||
t.Fatalf("stats: %+v", stats)
|
||||
}
|
||||
}
|
||||
57
workers/analyse/internal/pipelinestatus/status.go
Normal file
57
workers/analyse/internal/pipelinestatus/status.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package pipelinestatus
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
amqp "github.com/rabbitmq/amqp091-go"
|
||||
)
|
||||
|
||||
// Status values carried in Event.Status.
const (
|
||||
StatusPending = "pending"
|
||||
StatusInProgress = "in_progress"
|
||||
StatusDone = "done"
|
||||
StatusError = "error"
|
||||
)
|
||||
|
||||
// Stage values carried in Event.Stage, naming the pipeline step an event
// refers to.
const (
|
||||
StageQueued = "queued"
|
||||
StageTranscribing = "transcribing"
|
||||
StageAnalysing = "analysing"
|
||||
StageTagging = "tagging"
|
||||
StageCompleted = "completed"
|
||||
)
|
||||
|
||||
// Event is the JSON payload that Publish sends to the status queue. Filename
// and Error are omitted from the JSON when empty. Timestamp is Unix seconds;
// Publish fills it with the current time when it is left at zero.
type Event struct {
|
||||
TaskID string `json:"task_id"`
|
||||
Filename string `json:"filename,omitempty"`
|
||||
Status string `json:"status"`
|
||||
Stage string `json:"stage"`
|
||||
Error string `json:"error,omitempty"`
|
||||
Timestamp int64 `json:"timestamp"`
|
||||
}
|
||||
|
||||
func DeclareQueue(ch *amqp.Channel, queue string) error {
|
||||
_, err := ch.QueueDeclare(queue, true, false, false, false, nil)
|
||||
return err
|
||||
}
|
||||
|
||||
func Publish(ctx context.Context, ch *amqp.Channel, queue string, ev Event) error {
|
||||
if ev.Timestamp == 0 {
|
||||
ev.Timestamp = time.Now().Unix()
|
||||
}
|
||||
body, err := json.Marshal(ev)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := ch.PublishWithContext(ctx, "", queue, false, false, amqp.Publishing{
|
||||
ContentType: "application/json",
|
||||
Body: body,
|
||||
DeliveryMode: amqp.Persistent,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("publish status: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -18,6 +18,8 @@ import (
|
||||
"github.com/joho/godotenv"
|
||||
_ "github.com/jackc/pgx/v5/stdlib"
|
||||
amqp "github.com/rabbitmq/amqp091-go"
|
||||
|
||||
"github.com/postmet/tagging/internal/pipelinestatus"
|
||||
)
|
||||
|
||||
func init() {
|
||||
@@ -381,7 +383,7 @@ func saveTagging(ctx context.Context, db *sql.DB, taskID, filename, transcriptio
|
||||
filename = COALESCE(NULLIF($3, ''), filename),
|
||||
transcription = COALESCE(NULLIF($4, ''), transcription),
|
||||
updated_at = now(),
|
||||
status = CASE WHEN analysis IS NOT NULL THEN 'done' ELSE status END
|
||||
status = CASE WHEN analysis IS NOT NULL THEN 'done' ELSE 'in_progress' END
|
||||
WHERE task_id = $1
|
||||
RETURNING (analysis IS NOT NULL AND tagging IS NOT NULL)
|
||||
`, taskID, string(tagging), filename, transcription).Scan(&complete)
|
||||
@@ -391,6 +393,41 @@ func saveTagging(ctx context.Context, db *sql.DB, taskID, filename, transcriptio
|
||||
return complete, nil
|
||||
}
|
||||
|
||||
// setTaskInProgress upserts the results row for taskID with status
// 'in_progress'. A new row is inserted with the given filename (an empty
// filename is stored as NULL). If a row for task_id already exists, its status
// is overwritten, updated_at is refreshed, and filename is replaced only when
// a non-empty one is supplied. The ExecContext error is returned unchanged.
//
// NOTE(review): the conflict branch overwrites whatever status the row held
// (including 'done' or 'error') — confirm this is intended for redeliveries.
func setTaskInProgress(ctx context.Context, db *sql.DB, taskID, filename string) error {
|
||||
_, err := db.ExecContext(ctx, `
|
||||
INSERT INTO results (task_id, filename, status)
|
||||
VALUES ($1, NULLIF($2, ''), 'in_progress')
|
||||
ON CONFLICT (task_id) DO UPDATE SET
|
||||
status = 'in_progress',
|
||||
filename = COALESCE(NULLIF(EXCLUDED.filename, ''), results.filename),
|
||||
updated_at = now()
|
||||
`, taskID, filename)
|
||||
return err
|
||||
}
|
||||
|
||||
// setTaskError upserts the results row for taskID with status 'error'. It
// inserts a row carrying only task_id and status or, when the row already
// exists, sets status to 'error' and refreshes updated_at. The ExecContext
// error is returned unchanged.
func setTaskError(ctx context.Context, db *sql.DB, taskID string) error {
|
||||
_, err := db.ExecContext(ctx, `
|
||||
INSERT INTO results (task_id, status) VALUES ($1, 'error')
|
||||
ON CONFLICT (task_id) DO UPDATE SET status = 'error', updated_at = now()
|
||||
`, taskID)
|
||||
return err
|
||||
}
|
||||
|
||||
func publishStatus(ctx context.Context, ch *amqp.Channel, queue, taskID, filename, status, stage, errMsg string) {
|
||||
ev := pipelinestatus.Event{
|
||||
TaskID: taskID,
|
||||
Filename: filename,
|
||||
Status: status,
|
||||
Stage: stage,
|
||||
}
|
||||
if errMsg != "" {
|
||||
ev.Error = errMsg
|
||||
}
|
||||
if err := pipelinestatus.Publish(ctx, ch, queue, ev); err != nil {
|
||||
slog.Warn("status publish failed", "worker", "tagging", "task_id", taskID, "stage", stage, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ===================== MAIN =====================
|
||||
|
||||
func loadDotenv() {
|
||||
@@ -414,6 +451,7 @@ func main() {
|
||||
model := os.Getenv("YANDEX_MODEL")
|
||||
inputQueue := getenv("TAGGING_QUEUE", "tagging")
|
||||
finalQueue := getenv("FINAL_QUEUE", "final")
|
||||
statusQueue := getenv("STATUS_QUEUE", "pipeline.status")
|
||||
|
||||
if token == "" {
|
||||
slog.Error("YANDEX_API_KEY is required")
|
||||
@@ -447,6 +485,10 @@ func main() {
|
||||
slog.Error("declare queue failed", "queue", finalQueue, "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
if err := pipelinestatus.DeclareQueue(ch, statusQueue); err != nil {
|
||||
slog.Error("declare status queue failed", "queue", statusQueue, "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
ch.Qos(1, 0, false)
|
||||
|
||||
msgs, err := ch.Consume(inputQueue, "", false, false, false, false, nil)
|
||||
@@ -454,7 +496,7 @@ func main() {
|
||||
slog.Error("consume failed", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
slog.Info("worker started", "worker", "tagging", "queue", inputQueue, "model", model)
|
||||
slog.Info("worker started", "worker", "tagging", "queue", inputQueue, "status_queue", statusQueue, "model", model)
|
||||
|
||||
for d := range msgs {
|
||||
taskStart := time.Now()
|
||||
@@ -485,9 +527,18 @@ func main() {
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
|
||||
|
||||
if err := setTaskInProgress(ctx, db, task.TaskID, task.Filename); err != nil {
|
||||
slog.Warn("set in_progress failed", "worker", "tagging", "task_id", task.TaskID, "error", err)
|
||||
}
|
||||
publishStatus(ctx, ch, statusQueue, task.TaskID, task.Filename,
|
||||
pipelinestatus.StatusInProgress, pipelinestatus.StageTagging, "")
|
||||
|
||||
result, err := classify(ctx, task.TaskID, model, task.Transcription)
|
||||
if err != nil {
|
||||
cancel()
|
||||
_ = setTaskError(context.Background(), db, task.TaskID)
|
||||
publishStatus(context.Background(), ch, statusQueue, task.TaskID, task.Filename,
|
||||
pipelinestatus.StatusError, pipelinestatus.StageTagging, err.Error())
|
||||
slog.Warn("task failed, discarded",
|
||||
"worker", "tagging", "task_id", task.TaskID,
|
||||
"llm_calls", 1, "error", err)
|
||||
@@ -506,7 +557,7 @@ func main() {
|
||||
}
|
||||
|
||||
if complete {
|
||||
notifyFinal(ctx, ch, db, finalQueue, task.TaskID, "tagging")
|
||||
notifyFinal(ctx, ch, db, finalQueue, statusQueue, task.TaskID, task.Filename, "tagging")
|
||||
slog.Info("task complete", "worker", "tagging", "task_id", task.TaskID,
|
||||
"was_last", "tagging", "L1", result.L1,
|
||||
"llm_calls", 1, "duration_ms", time.Since(taskStart).Milliseconds())
|
||||
@@ -591,12 +642,14 @@ func loadFinalPayload(ctx context.Context, db *sql.DB, taskID string) ([]byte, e
|
||||
return json.Marshal(msg)
|
||||
}
|
||||
|
||||
func notifyFinal(ctx context.Context, ch *amqp.Channel, db *sql.DB, queue, taskID, worker string) {
|
||||
func notifyFinal(ctx context.Context, ch *amqp.Channel, db *sql.DB, queue, statusQueue, taskID, filename, worker string) {
|
||||
body, err := loadFinalPayload(ctx, db, taskID)
|
||||
if err != nil {
|
||||
slog.Warn("load final payload failed", "worker", worker, "task_id", taskID, "error", err)
|
||||
return
|
||||
}
|
||||
publishStatus(ctx, ch, statusQueue, taskID, filename,
|
||||
pipelinestatus.StatusDone, pipelinestatus.StageCompleted, "")
|
||||
if err := ch.PublishWithContext(ctx, "", queue, false, false,
|
||||
amqp.Publishing{
|
||||
ContentType: "application/json",
|
||||
|
||||
123
workers/tagging/cmd/tagging/main_test.go
Normal file
123
workers/tagging/cmd/tagging/main_test.go
Normal file
@@ -0,0 +1,123 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestTruncate(t *testing.T) {
|
||||
if got := truncate("ok", 5); got != "ok" {
|
||||
t.Fatalf("got %q", got)
|
||||
}
|
||||
got := truncate(strings.Repeat("x", 30), 5)
|
||||
if len(got) != 8 {
|
||||
t.Fatalf("got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenFingerprint(t *testing.T) {
|
||||
if tokenFingerprint("x") != "***" {
|
||||
t.Fatal("expected mask for short token")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPromptContainsTranscription(t *testing.T) {
|
||||
text := "клиент спрашивает где груз"
|
||||
p := buildPrompt(text)
|
||||
if !strings.Contains(p, text) {
|
||||
t.Fatal("prompt must include transcription")
|
||||
}
|
||||
if !strings.Contains(p, `"L1"`) {
|
||||
t.Fatal("prompt must describe JSON output")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassify(t *testing.T) {
|
||||
t.Setenv("YANDEX_API_KEY", "test-key")
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if auth := r.Header.Get("Authorization"); auth != "Bearer test-key" {
|
||||
t.Fatalf("auth: got %q", auth)
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"choices": []map[string]any{
|
||||
{"message": map[string]any{"content": `{
|
||||
"L1":"tracking",
|
||||
"L2":"location_request",
|
||||
"L3":"",
|
||||
"risk_level":"low",
|
||||
"has_action_items":false,
|
||||
"has_deadline":false
|
||||
}`}},
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
t.Setenv("YANDEX_API_URL", srv.URL)
|
||||
|
||||
result, err := classify(context.Background(), "task-1", "model", "где мой груз")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if result.L1 != "tracking" || result.L2 != "location_request" {
|
||||
t.Fatalf("unexpected result: %+v", result)
|
||||
}
|
||||
if result.RiskLevel != "low" {
|
||||
t.Fatalf("risk_level: %q", result.RiskLevel)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyInvalidJSON(t *testing.T) {
|
||||
t.Setenv("YANDEX_API_KEY", "k")
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"choices": []map[string]any{{"message": map[string]any{"content": "not-json"}}},
|
||||
})
|
||||
}))
|
||||
defer srv.Close()
|
||||
t.Setenv("YANDEX_API_URL", srv.URL)
|
||||
|
||||
_, err := classify(context.Background(), "t", "m", "text")
|
||||
if err == nil {
|
||||
t.Fatal("expected parse error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractFilePath(t *testing.T) {
|
||||
body := []byte(`{"file_path":"/data/storage/processing/x.wav"}`)
|
||||
if got := extractFilePath(body); got != "/data/storage/processing/x.wav" {
|
||||
t.Fatalf("got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteProcessingFile(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "processing", "01.wav")
|
||||
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(path, []byte("a"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
deleteProcessingFile(path, "01", "tagging")
|
||||
if _, err := os.Stat(path); !os.IsNotExist(err) {
|
||||
t.Fatal("expected file removed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetenv(t *testing.T) {
|
||||
t.Setenv("TAGGING_QUEUE", "custom")
|
||||
if got := getenv("TAGGING_QUEUE", "tagging"); got != "custom" {
|
||||
t.Fatalf("got %q", got)
|
||||
}
|
||||
if got := getenv("UNSET_VAR_XYZ", "default"); got != "default" {
|
||||
t.Fatalf("got %q", got)
|
||||
}
|
||||
}
|
||||
57
workers/tagging/internal/pipelinestatus/status.go
Normal file
57
workers/tagging/internal/pipelinestatus/status.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package pipelinestatus
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
amqp "github.com/rabbitmq/amqp091-go"
|
||||
)
|
||||
|
||||
// Status values carried in Event.Status.
const (
|
||||
StatusPending = "pending"
|
||||
StatusInProgress = "in_progress"
|
||||
StatusDone = "done"
|
||||
StatusError = "error"
|
||||
)
|
||||
|
||||
// Stage values carried in Event.Stage, naming the pipeline step an event
// refers to.
const (
|
||||
StageQueued = "queued"
|
||||
StageTranscribing = "transcribing"
|
||||
StageAnalysing = "analysing"
|
||||
StageTagging = "tagging"
|
||||
StageCompleted = "completed"
|
||||
)
|
||||
|
||||
// Event is the JSON payload that Publish sends to the status queue. Filename
// and Error are omitted from the JSON when empty. Timestamp is Unix seconds;
// Publish fills it with the current time when it is left at zero.
type Event struct {
|
||||
TaskID string `json:"task_id"`
|
||||
Filename string `json:"filename,omitempty"`
|
||||
Status string `json:"status"`
|
||||
Stage string `json:"stage"`
|
||||
Error string `json:"error,omitempty"`
|
||||
Timestamp int64 `json:"timestamp"`
|
||||
}
|
||||
|
||||
func DeclareQueue(ch *amqp.Channel, queue string) error {
|
||||
_, err := ch.QueueDeclare(queue, true, false, false, false, nil)
|
||||
return err
|
||||
}
|
||||
|
||||
func Publish(ctx context.Context, ch *amqp.Channel, queue string, ev Event) error {
|
||||
if ev.Timestamp == 0 {
|
||||
ev.Timestamp = time.Now().Unix()
|
||||
}
|
||||
body, err := json.Marshal(ev)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := ch.PublishWithContext(ctx, "", queue, false, false, amqp.Publishing{
|
||||
ContentType: "application/json",
|
||||
Body: body,
|
||||
DeliveryMode: amqp.Persistent,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("publish status: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -15,6 +15,7 @@ type Config struct {
|
||||
InputExchange string
|
||||
InputRoutingKey string
|
||||
Prefetch int
|
||||
StatusQueue string
|
||||
|
||||
NexaraBaseURL string
|
||||
NexaraAPIKey string
|
||||
@@ -38,6 +39,7 @@ func Load() Config {
|
||||
InputExchange: getEnv("RABBITMQ_EXCHANGE", "audio_pipeline"),
|
||||
InputRoutingKey: getEnv("RABBITMQ_ROUTING_KEY", "audio.new"),
|
||||
Prefetch: getInt("PREFETCH", 1),
|
||||
StatusQueue: getEnv("STATUS_QUEUE", "pipeline.status"),
|
||||
|
||||
NexaraBaseURL: getEnv("NEXARA_BASE_URL", "https://api.nexara.ru"),
|
||||
NexaraAPIKey: os.Getenv("NEXARA_API_KEY"),
|
||||
|
||||
51
workers/transcribe/internal/config/config_test.go
Normal file
51
workers/transcribe/internal/config/config_test.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestLoadDefaults(t *testing.T) {
|
||||
t.Setenv("RABBITMQ_URL", "")
|
||||
t.Setenv("STATUS_QUEUE", "")
|
||||
t.Setenv("NEXARA_TIMEOUT", "")
|
||||
|
||||
cfg := Load()
|
||||
if cfg.InputQueue != "transcribe.tasks" {
|
||||
t.Fatalf("InputQueue: got %q", cfg.InputQueue)
|
||||
}
|
||||
if cfg.StatusQueue != "pipeline.status" {
|
||||
t.Fatalf("StatusQueue: got %q", cfg.StatusQueue)
|
||||
}
|
||||
if cfg.NexaraTimeout != 10*time.Minute {
|
||||
t.Fatalf("NexaraTimeout: got %v", cfg.NexaraTimeout)
|
||||
}
|
||||
if cfg.PromptsSection != 1 {
|
||||
t.Fatalf("PromptsSection: got %d", cfg.PromptsSection)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadFromEnv(t *testing.T) {
|
||||
t.Setenv("INPUT_QUEUE", "custom.tasks")
|
||||
t.Setenv("STATUS_QUEUE", "status.events")
|
||||
t.Setenv("PREFETCH", "4")
|
||||
t.Setenv("NEXARA_TIMEOUT", "2m")
|
||||
t.Setenv("PROMPTS_SECTION", "2")
|
||||
|
||||
cfg := Load()
|
||||
if cfg.InputQueue != "custom.tasks" {
|
||||
t.Fatalf("InputQueue: got %q", cfg.InputQueue)
|
||||
}
|
||||
if cfg.StatusQueue != "status.events" {
|
||||
t.Fatalf("StatusQueue: got %q", cfg.StatusQueue)
|
||||
}
|
||||
if cfg.Prefetch != 4 {
|
||||
t.Fatalf("Prefetch: got %d", cfg.Prefetch)
|
||||
}
|
||||
if cfg.NexaraTimeout != 2*time.Minute {
|
||||
t.Fatalf("NexaraTimeout: got %v", cfg.NexaraTimeout)
|
||||
}
|
||||
if cfg.PromptsSection != 2 {
|
||||
t.Fatalf("PromptsSection: got %d", cfg.PromptsSection)
|
||||
}
|
||||
}
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/postmet/transcribe/internal/config"
|
||||
"github.com/postmet/transcribe/internal/models"
|
||||
"github.com/postmet/transcribe/internal/nexara"
|
||||
"github.com/postmet/transcribe/internal/pipelinestatus"
|
||||
"github.com/postmet/transcribe/internal/prompts"
|
||||
)
|
||||
|
||||
@@ -70,6 +71,9 @@ func setupTopology(ch *amqp.Channel, cfg config.Config) error {
|
||||
return fmt.Errorf("bind queue %s: %w", q, err)
|
||||
}
|
||||
}
|
||||
if err := pipelinestatus.DeclareQueue(ch, cfg.StatusQueue); err != nil {
|
||||
return fmt.Errorf("declare status queue: %w", err)
|
||||
}
|
||||
|
||||
return ch.Qos(cfg.Prefetch, 0, false)
|
||||
}
|
||||
@@ -84,7 +88,11 @@ func (c *Consumer) Run(ctx context.Context) error {
|
||||
return err
|
||||
}
|
||||
|
||||
slog.Info("transcribe worker started", "queue", c.cfg.InputQueue, "output_exchange", c.cfg.OutputExchange)
|
||||
slog.Info("transcribe worker started",
|
||||
"queue", c.cfg.InputQueue,
|
||||
"output_exchange", c.cfg.OutputExchange,
|
||||
"status_queue", c.cfg.StatusQueue,
|
||||
)
|
||||
|
||||
for {
|
||||
select {
|
||||
@@ -112,9 +120,12 @@ func (c *Consumer) handle(ctx context.Context, d amqp.Delivery) {
|
||||
txCtx, cancel := context.WithTimeout(ctx, c.cfg.NexaraTimeout+30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
c.publishStatus(txCtx, task.TaskID, task.Filename, pipelinestatus.StatusInProgress, pipelinestatus.StageTranscribing, "")
|
||||
|
||||
text, lang, segments, err := c.nexara.TranscribeFile(txCtx, task.FilePath)
|
||||
if err != nil {
|
||||
slog.Warn("transcription failed", "task_id", task.TaskID, "error", err)
|
||||
c.publishStatus(txCtx, task.TaskID, task.Filename, pipelinestatus.StatusError, pipelinestatus.StageTranscribing, err.Error())
|
||||
_ = d.Nack(false, false)
|
||||
return
|
||||
}
|
||||
@@ -122,6 +133,7 @@ func (c *Consumer) handle(ctx context.Context, d amqp.Delivery) {
|
||||
promptList, err := c.prompts.Load(txCtx)
|
||||
if err != nil {
|
||||
slog.Warn("prompts load failed", "task_id", task.TaskID, "error", err)
|
||||
c.publishStatus(txCtx, task.TaskID, task.Filename, pipelinestatus.StatusError, pipelinestatus.StageTranscribing, err.Error())
|
||||
_ = d.Nack(false, false)
|
||||
return
|
||||
}
|
||||
@@ -170,3 +182,18 @@ func (c *Consumer) handle(ctx context.Context, d amqp.Delivery) {
|
||||
slog.Info("transcribed", "task_id", task.TaskID, "language", lang, "chars", len(text), "segments", len(segments), "prompts", len(promptList))
|
||||
_ = d.Ack(false)
|
||||
}
|
||||
|
||||
func (c *Consumer) publishStatus(ctx context.Context, taskID, filename, status, stage, errMsg string) {
|
||||
ev := pipelinestatus.Event{
|
||||
TaskID: taskID,
|
||||
Filename: filename,
|
||||
Status: status,
|
||||
Stage: stage,
|
||||
}
|
||||
if errMsg != "" {
|
||||
ev.Error = errMsg
|
||||
}
|
||||
if err := pipelinestatus.Publish(ctx, c.ch, c.cfg.StatusQueue, ev); err != nil {
|
||||
slog.Warn("status publish failed", "task_id", taskID, "stage", stage, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
58
workers/transcribe/internal/models/models_test.go
Normal file
58
workers/transcribe/internal/models/models_test.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestAudioTaskRoundTrip(t *testing.T) {
|
||||
src := AudioTask{
|
||||
TaskID: "01HX",
|
||||
FilePath: "/data/processing/01HX.wav",
|
||||
Filename: "call.wav",
|
||||
Size: 2048,
|
||||
CreatedAt: 1717843200,
|
||||
}
|
||||
body, err := json.Marshal(src)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var got AudioTask
|
||||
if err := json.Unmarshal(body, &got); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got != src {
|
||||
t.Fatalf("round-trip mismatch: %+v vs %+v", got, src)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTranscriptionResultRoundTrip(t *testing.T) {
|
||||
src := TranscriptionResult{
|
||||
TaskID: "01HX",
|
||||
Filename: "call.wav",
|
||||
FilePath: "/data/processing/01HX.wav",
|
||||
Transcription: "текст звонка",
|
||||
Language: "ru",
|
||||
Segments: []Segment{
|
||||
{Start: 0, End: 1.2, Text: "текст"},
|
||||
},
|
||||
Prompts: []Prompt{
|
||||
{ID: 1, IDSection: 1, Name: "behavioral", Prompt: "analyze", DtCreate: "2026-01-01"},
|
||||
},
|
||||
TranscribedAt: 1717843200,
|
||||
}
|
||||
body, err := json.Marshal(src)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var got TranscriptionResult
|
||||
if err := json.Unmarshal(body, &got); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got.TaskID != src.TaskID || got.Transcription != src.Transcription {
|
||||
t.Fatalf("mismatch: %+v", got)
|
||||
}
|
||||
if len(got.Segments) != 1 || len(got.Prompts) != 1 {
|
||||
t.Fatalf("nested fields lost: %+v", got)
|
||||
}
|
||||
}
|
||||
118
workers/transcribe/internal/nexara/nexara_test.go
Normal file
118
workers/transcribe/internal/nexara/nexara_test.go
Normal file
@@ -0,0 +1,118 @@
|
||||
package nexara
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestTranscribeFileSuccess(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
t.Fatalf("method: got %s", r.Method)
|
||||
}
|
||||
if !strings.HasPrefix(r.Header.Get("Content-Type"), "multipart/form-data") {
|
||||
t.Fatalf("content-type: got %s", r.Header.Get("Content-Type"))
|
||||
}
|
||||
if auth := r.Header.Get("Authorization"); auth != "Bearer nexara-key" {
|
||||
t.Fatalf("auth: got %q", auth)
|
||||
}
|
||||
if err := r.ParseMultipartForm(1 << 20); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if r.FormValue("model") != "whisper-1" {
|
||||
t.Fatalf("model: got %q", r.FormValue("model"))
|
||||
}
|
||||
if r.FormValue("response_format") != "json" {
|
||||
t.Fatalf("response_format: got %q", r.FormValue("response_format"))
|
||||
}
|
||||
file, _, err := r.FormFile("file")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer file.Close()
|
||||
buf := make([]byte, 64)
|
||||
n, _ := file.Read(buf)
|
||||
if string(buf[:n]) != "audio-bytes" {
|
||||
t.Fatalf("file content: got %q", string(buf[:n]))
|
||||
}
|
||||
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"text": "привет мир",
|
||||
"language": "ru",
|
||||
"segments": []map[string]any{
|
||||
{"start": 0.0, "end": 1.5, "text": "привет"},
|
||||
{"start": 1.5, "end": 3.0, "text": "мир"},
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
dir := t.TempDir()
|
||||
audioPath := filepath.Join(dir, "test.wav")
|
||||
if err := os.WriteFile(audioPath, []byte("audio-bytes"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
client := New(srv.URL, "nexara-key", "whisper-1", 5*time.Second)
|
||||
text, lang, segments, err := client.TranscribeFile(context.Background(), audioPath)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if text != "привет мир" {
|
||||
t.Fatalf("text: got %q", text)
|
||||
}
|
||||
if lang != "ru" {
|
||||
t.Fatalf("language: got %q", lang)
|
||||
}
|
||||
if len(segments) != 2 {
|
||||
t.Fatalf("segments: got %d", len(segments))
|
||||
}
|
||||
if segments[0].Text != "привет" || segments[1].End != 3.0 {
|
||||
t.Fatalf("unexpected segments: %+v", segments)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTranscribeFileAPIError(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
_, _ = w.Write([]byte("invalid key"))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
dir := t.TempDir()
|
||||
audioPath := filepath.Join(dir, "test.wav")
|
||||
if err := os.WriteFile(audioPath, []byte("x"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
client := New(srv.URL, "bad", "", 5*time.Second)
|
||||
_, _, _, err := client.TranscribeFile(context.Background(), audioPath)
|
||||
if err == nil {
|
||||
t.Fatal("expected error on 401")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "401") {
|
||||
t.Fatalf("error should mention status: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTranscribeFileMissingFile(t *testing.T) {
|
||||
client := New("http://localhost", "key", "", time.Second)
|
||||
_, _, _, err := client.TranscribeFile(context.Background(), filepath.Join(t.TempDir(), "missing.wav"))
|
||||
if err == nil {
|
||||
t.Fatal("expected error for missing file")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewTrimsBaseURL(t *testing.T) {
|
||||
c := New("https://api.example.com/", "k", "m", time.Second)
|
||||
if !strings.HasSuffix(c.apiURL, "/api/v1/audio/transcriptions") {
|
||||
t.Fatalf("apiURL: got %s", c.apiURL)
|
||||
}
|
||||
}
|
||||
57
workers/transcribe/internal/pipelinestatus/status.go
Normal file
57
workers/transcribe/internal/pipelinestatus/status.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package pipelinestatus
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
amqp "github.com/rabbitmq/amqp091-go"
|
||||
)
|
||||
|
||||
// Status values for a pipeline task (presumably the values carried in
// Event.Status — confirm against callers).
const (
	// StatusPending is the "pending" status string.
	StatusPending = "pending"
	// StatusInProgress is the "in_progress" status string.
	StatusInProgress = "in_progress"
	// StatusDone is the "done" status string.
	StatusDone = "done"
	// StatusError is the "error" status string.
	StatusError = "error"
)
|
||||
|
||||
// Stage names for a pipeline task (presumably the values carried in
// Event.Stage — confirm against callers).
const (
	// StageQueued is the "queued" stage string.
	StageQueued = "queued"
	// StageTranscribing is the "transcribing" stage string.
	StageTranscribing = "transcribing"
	// StageAnalysing is the "analysing" stage string.
	StageAnalysing = "analysing"
	// StageTagging is the "tagging" stage string.
	StageTagging = "tagging"
	// StageCompleted is the "completed" stage string.
	StageCompleted = "completed"
)
|
||||
|
||||
// Event is the JSON message describing a task's status. Filename and Error are
// omitted from the encoded JSON when empty.
type Event struct {
	// TaskID is encoded as "task_id".
	TaskID string `json:"task_id"`
	// Filename is encoded as "filename" and omitted when empty.
	Filename string `json:"filename,omitempty"`
	// Status is encoded as "status".
	Status string `json:"status"`
	// Stage is encoded as "stage".
	Stage string `json:"stage"`
	// Error is encoded as "error" and omitted when empty.
	Error string `json:"error,omitempty"`
	// Timestamp is encoded as "timestamp"; Publish sets it to the current
	// Unix time in seconds when it is zero.
	Timestamp int64 `json:"timestamp"`
}
|
||||
|
||||
func DeclareQueue(ch *amqp.Channel, queue string) error {
|
||||
_, err := ch.QueueDeclare(queue, true, false, false, false, nil)
|
||||
return err
|
||||
}
|
||||
|
||||
func Publish(ctx context.Context, ch *amqp.Channel, queue string, ev Event) error {
|
||||
if ev.Timestamp == 0 {
|
||||
ev.Timestamp = time.Now().Unix()
|
||||
}
|
||||
body, err := json.Marshal(ev)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := ch.PublishWithContext(ctx, "", queue, false, false, amqp.Publishing{
|
||||
ContentType: "application/json",
|
||||
Body: body,
|
||||
DeliveryMode: amqp.Persistent,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("publish status: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
116
workers/transcribe/internal/prompts/prompts_test.go
Normal file
116
workers/transcribe/internal/prompts/prompts_test.go
Normal file
@@ -0,0 +1,116 @@
|
||||
package prompts
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// samplePrompts returns the JSON encoding of three fixture prompts: two with
// id_section 1 ("behavioral", "client_data") and one with id_section 2
// ("other"). It panics if encoding fails.
func samplePrompts() []byte {
	// entry builds one fixture record; every record shares the same dt_create.
	entry := func(id, section int, name, prompt string) map[string]any {
		return map[string]any{
			"id":         id,
			"id_section": section,
			"name":       name,
			"prompt":     prompt,
			"dt_create":  "2026-01-01",
		}
	}

	fixture := []map[string]any{
		entry(1, 1, "behavioral", "p1"),
		entry(2, 1, "client_data", "p2"),
		entry(3, 2, "other", "p3"),
	}

	encoded, err := json.Marshal(fixture)
	if err != nil {
		panic(err)
	}
	return encoded
}
|
||||
|
||||
func TestLoadStatic(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "prompts.json")
|
||||
if err := os.WriteFile(path, samplePrompts(), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
l := New("static", path, "", "", 1)
|
||||
got, err := l.Load(context.Background())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(got) != 2 {
|
||||
t.Fatalf("want 2 prompts for section 1, got %d", len(got))
|
||||
}
|
||||
if got[0].Name != "behavioral" || got[1].Name != "client_data" {
|
||||
t.Fatalf("unexpected names: %v, %v", got[0].Name, got[1].Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadStaticAllSections(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "prompts.json")
|
||||
if err := os.WriteFile(path, samplePrompts(), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
l := New("static", path, "", "", 0)
|
||||
got, err := l.Load(context.Background())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(got) != 3 {
|
||||
t.Fatalf("want 3 prompts without filter, got %d", len(got))
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadStaticMissingFile(t *testing.T) {
|
||||
l := New("static", filepath.Join(t.TempDir(), "missing.json"), "", "", 1)
|
||||
_, err := l.Load(context.Background())
|
||||
if err == nil {
|
||||
t.Fatal("expected error for missing file")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadHTTP(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/metrics/" {
|
||||
t.Fatalf("path: got %s", r.URL.Path)
|
||||
}
|
||||
if r.URL.Query().Get("id_section") != "1" {
|
||||
t.Fatalf("id_section: got %s", r.URL.Query().Get("id_section"))
|
||||
}
|
||||
if auth := r.Header.Get("Authorization"); auth != "Bearer test-key" {
|
||||
t.Fatalf("auth: got %q", auth)
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write(samplePrompts())
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
l := New("http", "", srv.URL, "test-key", 1)
|
||||
got, err := l.Load(context.Background())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(got) != 2 {
|
||||
t.Fatalf("want 2 prompts, got %d", len(got))
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadHTTPErrorStatus(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
_, _ = w.Write([]byte("boom"))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
l := New("http", "", srv.URL, "", 1)
|
||||
_, err := l.Load(context.Background())
|
||||
if err == nil {
|
||||
t.Fatal("expected error on 500")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadHTTPMissingBaseURL(t *testing.T) {
|
||||
l := New("http", "", "", "", 1)
|
||||
_, err := l.Load(context.Background())
|
||||
if err == nil {
|
||||
t.Fatal("expected error without base URL")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user