diff --git a/cloud/src/handlers/embeddings.js b/cloud/src/handlers/embeddings.js new file mode 100644 index 00000000..52d86be0 --- /dev/null +++ b/cloud/src/handlers/embeddings.js @@ -0,0 +1,285 @@ +import { getModelInfoCore } from "open-sse/services/model.js"; +import { handleEmbeddingsCore } from "open-sse/handlers/embeddingsCore.js"; +import { errorResponse } from "open-sse/utils/error.js"; +import { + checkFallbackError, + isAccountUnavailable, + getEarliestRateLimitedUntil, + getUnavailableUntil, + formatRetryAfter +} from "open-sse/services/accountFallback.js"; +import { HTTP_STATUS } from "open-sse/config/constants.js"; +import * as log from "../utils/logger.js"; +import { parseApiKey, extractBearerToken } from "../utils/apiKey.js"; +import { getMachineData, saveMachineData } from "../services/storage.js"; + +/** + * Handle POST /v1/embeddings and /{machineId}/v1/embeddings requests. + * + * Follows the same auth + fallback pattern as handleChat: + * 1. Resolve machineId (from URL or API key) + * 2. Validate API key + * 3. Parse model → provider/model + * 4. Get provider credentials with fallback loop + * 5. 
/**
 * Handle POST /v1/embeddings and /{machineId}/v1/embeddings requests.
 *
 * Flow (same auth + fallback pattern the original notes for handleChat):
 *   1. Resolve machineId (URL override, otherwise decoded from the API key)
 *   2. Validate the bearer key against the machine's stored keys
 *   3. Parse and validate the JSON body (model, input)
 *   4. Resolve model → provider/model
 *   5. Try provider connections one by one, delegating each attempt to
 *      handleEmbeddingsCore (open-sse)
 *
 * @param {Request} request
 * @param {object} env - Cloudflare env bindings
 * @param {object} ctx - Execution context (not used by this handler)
 * @param {string|null} machineIdOverride - From URL path (old format), or null (new format)
 * @returns {Promise<Response>}
 */
export async function handleEmbeddings(request, env, ctx, machineIdOverride = null) {
  if (request.method === "OPTIONS") {
    return new Response(null, {
      headers: {
        "Access-Control-Allow-Origin": "*",
        "Access-Control-Allow-Methods": "GET, POST, OPTIONS",
        "Access-Control-Allow-Headers": "*"
      }
    });
  }

  // Resolve machineId
  let machineId = machineIdOverride;

  if (!machineId) {
    const apiKey = extractBearerToken(request);
    if (!apiKey) return errorResponse(HTTP_STATUS.UNAUTHORIZED, "Missing API key");

    const parsed = await parseApiKey(apiKey);
    if (!parsed) return errorResponse(HTTP_STATUS.UNAUTHORIZED, "Invalid API key format");

    if (!parsed.isNewFormat || !parsed.machineId) {
      return errorResponse(
        HTTP_STATUS.BAD_REQUEST,
        "API key does not contain machineId. Use /{machineId}/v1/... endpoint for old format keys."
      );
    }
    machineId = parsed.machineId;
  }

  // Validate API key
  if (!(await validateApiKey(request, machineId, env))) {
    return errorResponse(HTTP_STATUS.UNAUTHORIZED, "Invalid API key");
  }

  // Parse body
  let body;
  try {
    body = await request.json();
  } catch {
    return errorResponse(HTTP_STATUS.BAD_REQUEST, "Invalid JSON body");
  }
  // `null`, numbers and strings are valid JSON but not a request object;
  // without this guard `body.model` below throws a TypeError on `null`.
  if (body === null || typeof body !== "object") {
    return errorResponse(HTTP_STATUS.BAD_REQUEST, "Invalid JSON body");
  }

  const modelStr = body.model;
  if (!modelStr) return errorResponse(HTTP_STATUS.BAD_REQUEST, "Missing model");

  if (!body.input) return errorResponse(HTTP_STATUS.BAD_REQUEST, "Missing required field: input");

  log.info("EMBEDDINGS", `${machineId} | ${modelStr}`);

  // Resolve model info
  const data = await getMachineData(machineId, env);
  const modelInfo = await getModelInfoCore(modelStr, data?.modelAliases || {});
  if (!modelInfo.provider) return errorResponse(HTTP_STATUS.BAD_REQUEST, "Invalid model format");

  const { provider, model } = modelInfo;
  log.info("EMBEDDINGS_MODEL", `${provider.toUpperCase()} | ${model}`);

  // Provider credential + fallback loop.
  // Every connection that failed during THIS request is remembered, so the
  // loop visits each connection at most once and is guaranteed to terminate
  // even when the stored cooldown has already elapsed.
  const triedConnectionIds = new Set();
  let lastError = null;
  let lastStatus = null;

  while (true) {
    const credentials = await getProviderCredentials(machineId, provider, env, triedConnectionIds);

    if (!credentials || credentials.allRateLimited) {
      if (credentials?.allRateLimited) {
        const errorMsg = lastError || credentials.lastError || "Unavailable";
        const msg = `[${provider}/${model}] ${errorMsg} (${credentials.retryAfterHuman})`;
        const status = lastStatus || Number(credentials.lastErrorCode) || HTTP_STATUS.SERVICE_UNAVAILABLE;
        log.warn("EMBEDDINGS", `${provider.toUpperCase()} | ${msg}`);
        return jsonErrorResponse(status, msg, {
          "Retry-After": String(retryAfterSeconds(credentials.retryAfter))
        });
      }
      if (triedConnectionIds.size === 0) {
        return errorResponse(HTTP_STATUS.BAD_REQUEST, `No credentials for provider: ${provider}`);
      }
      log.warn("EMBEDDINGS", `${provider.toUpperCase()} | no more accounts`);
      // Same `{ error: { message } }` envelope as every other error above.
      return jsonErrorResponse(
        lastStatus || HTTP_STATUS.SERVICE_UNAVAILABLE,
        lastError || "All accounts unavailable"
      );
    }

    log.debug("EMBEDDINGS", `account=${credentials.id}`, { provider });

    const result = await handleEmbeddingsCore({
      body,
      modelInfo: { provider, model },
      credentials,
      log,
      onCredentialsRefreshed: async (newCreds) => {
        await updateCredentials(machineId, credentials.id, newCreds, env);
      },
      onRequestSuccess: async () => {
        await clearAccountError(machineId, credentials.id, credentials, env);
      }
    });

    if (result.success) return result.response;

    const { shouldFallback } = checkFallbackError(result.status, result.error);

    if (shouldFallback) {
      log.warn("EMBEDDINGS_FALLBACK", `${provider.toUpperCase()} | ${credentials.id} | ${result.status}`);
      await markAccountUnavailable(machineId, credentials.id, result.status, result.error, env);
      triedConnectionIds.add(credentials.id);
      lastError = result.error;
      lastStatus = result.status;
      continue;
    }

    return result.response;
  }
}

// ─── Helpers (originally mirrored from chat.js) ──────────────────────────────

/**
 * Build a JSON error response using the `{ error: { message } }` envelope.
 *
 * @param {number} status - HTTP status code
 * @param {string} message - Human readable error text
 * @param {object} [extraHeaders] - Additional response headers
 * @returns {Response}
 */
function jsonErrorResponse(status, message, extraHeaders = {}) {
  return new Response(JSON.stringify({ error: { message } }), {
    status,
    headers: { "Content-Type": "application/json", ...extraHeaders }
  });
}

/**
 * Convert an absolute retry timestamp into a Retry-After value in seconds.
 * Never returns less than 1, and falls back to 1 when the timestamp cannot
 * be parsed (otherwise the header would read "NaN").
 *
 * @param {string|number|Date} retryAfter
 * @returns {number}
 */
function retryAfterSeconds(retryAfter) {
  const remainingMs = new Date(retryAfter).getTime() - Date.now();
  if (!Number.isFinite(remainingMs)) return 1;
  return Math.max(Math.ceil(remainingMs / 1000), 1);
}

/**
 * Normalize the exclusion argument of getProviderCredentials to a Set.
 *
 * @param {string|Iterable<string>|null} exclude - One id, a collection of ids, or nothing
 * @returns {Set<string>}
 */
function toExcludedIdSet(exclude) {
  if (exclude === null || exclude === undefined || exclude === "") return new Set();
  if (typeof exclude === "string") return new Set([exclude]);
  return new Set(exclude);
}

/**
 * Check that the request carries a Bearer key stored for this machine.
 *
 * @param {Request} request
 * @param {string} machineId
 * @param {object} env
 * @returns {Promise<boolean>}
 */
async function validateApiKey(request, machineId, env) {
  const authHeader = request.headers.get("Authorization");
  if (!authHeader?.startsWith("Bearer ")) return false;

  const apiKey = authHeader.slice(7);
  const data = await getMachineData(machineId, env);
  return data?.apiKeys?.some((k) => k.key === apiKey) || false;
}

/**
 * Pick the best usable connection for a provider.
 *
 * A connection is usable when it belongs to the provider, is active, is not
 * excluded and is not currently rate limited. The lowest `priority` wins.
 *
 * @param {string} machineId
 * @param {string} provider
 * @param {object} env
 * @param {string|Iterable<string>|null} [excludeConnectionId] - Connection id(s) to skip;
 *   a single id string (original contract) or any iterable of ids
 * @returns {Promise<object|null>} Credentials of the chosen connection,
 *   `{ allRateLimited: true, ... }` when connections exist but none is usable
 *   and a rate-limit expiry is known, or null when nothing is available
 */
async function getProviderCredentials(machineId, provider, env, excludeConnectionId = null) {
  const data = await getMachineData(machineId, env);
  if (!data?.providers) return null;

  const excludedIds = toExcludedIdSet(excludeConnectionId);

  const providerConnections = Object.entries(data.providers)
    .filter(([connId, conn]) => {
      if (conn.provider !== provider || !conn.isActive) return false;
      if (excludedIds.has(connId)) return false;
      if (isAccountUnavailable(conn.rateLimitedUntil)) return false;
      return true;
    })
    .sort((a, b) => (a[1].priority || 999) - (b[1].priority || 999));

  if (providerConnections.length === 0) {
    const allConnections = Object.entries(data.providers)
      .filter(([, conn]) => conn.provider === provider && conn.isActive)
      .map(([, conn]) => conn);
    const earliest = getEarliestRateLimitedUntil(allConnections);
    if (earliest) {
      // Report the error of the connection that becomes available first.
      const rateLimitedConns = allConnections.filter(
        (c) => c.rateLimitedUntil && new Date(c.rateLimitedUntil).getTime() > Date.now()
      );
      const earliestConn = rateLimitedConns.sort(
        (a, b) => new Date(a.rateLimitedUntil) - new Date(b.rateLimitedUntil)
      )[0];
      return {
        allRateLimited: true,
        retryAfter: earliest,
        retryAfterHuman: formatRetryAfter(earliest),
        lastError: earliestConn?.lastError || null,
        lastErrorCode: earliestConn?.errorCode || null
      };
    }
    return null;
  }

  const [connectionId, connection] = providerConnections[0];
  return {
    id: connectionId,
    apiKey: connection.apiKey,
    accessToken: connection.accessToken,
    refreshToken: connection.refreshToken,
    expiresAt: connection.expiresAt,
    projectId: connection.projectId,
    providerSpecificData: connection.providerSpecificData,
    status: connection.status,
    lastError: connection.lastError,
    rateLimitedUntil: connection.rateLimitedUntil
  };
}

/**
 * Persist a failure on a connection: cooldown expiry, error text/code and
 * the backoff level returned by checkFallbackError.
 *
 * @param {string} machineId
 * @param {string} connectionId
 * @param {number} status - Upstream HTTP status
 * @param {string} errorText - Upstream error text (stored truncated to 100 chars)
 * @param {object} env
 * @returns {Promise<void>}
 */
async function markAccountUnavailable(machineId, connectionId, status, errorText, env) {
  const data = await getMachineData(machineId, env);
  const conn = data?.providers?.[connectionId];
  if (!conn) return;

  const backoffLevel = conn.backoffLevel || 0;
  const { cooldownMs, newBackoffLevel } = checkFallbackError(status, errorText, backoffLevel);
  const rateLimitedUntil = getUnavailableUntil(cooldownMs);
  const reason = typeof errorText === "string" ? errorText.slice(0, 100) : "Provider error";
  const now = new Date().toISOString();

  conn.rateLimitedUntil = rateLimitedUntil;
  conn.status = "unavailable";
  conn.lastError = reason;
  conn.errorCode = status || null;
  conn.lastErrorAt = now;
  conn.backoffLevel = newBackoffLevel ?? backoffLevel;
  conn.updatedAt = now;

  await saveMachineData(machineId, data, env);
  log.warn("EMBEDDINGS_ACCOUNT", `${connectionId} | unavailable until ${rateLimitedUntil}`);
}

/**
 * Reset the stored error state of a connection after a successful request.
 * Does nothing (and performs no storage access) when the credentials handed
 * to the request carried no error marker.
 *
 * @param {string} machineId
 * @param {string} connectionId
 * @param {object} currentCredentials - Credentials object used for the request
 * @param {object} env
 * @returns {Promise<void>}
 */
async function clearAccountError(machineId, connectionId, currentCredentials, env) {
  const hasError =
    currentCredentials.status === "unavailable" ||
    currentCredentials.lastError ||
    currentCredentials.rateLimitedUntil;

  if (!hasError) return;

  const data = await getMachineData(machineId, env);
  const conn = data?.providers?.[connectionId];
  if (!conn) return;

  conn.status = "active";
  conn.lastError = null;
  // markAccountUnavailable stores errorCode; clear it too so a later
  // "all rate limited" answer cannot reuse a stale status code.
  conn.errorCode = null;
  conn.lastErrorAt = null;
  conn.rateLimitedUntil = null;
  conn.backoffLevel = 0;
  conn.updatedAt = new Date().toISOString();

  await saveMachineData(machineId, data, env);
  log.info("EMBEDDINGS_ACCOUNT", `${connectionId} | error cleared`);
}

/**
 * Persist refreshed credentials on a connection.
 * Only fields present in `newCredentials` are written, so a refresh that
 * yields just an apiKey does not erase the stored accessToken.
 *
 * @param {string} machineId
 * @param {string} connectionId
 * @param {object} newCredentials - May contain accessToken, apiKey, refreshToken, expiresIn (seconds)
 * @param {object} env
 * @returns {Promise<void>}
 */
async function updateCredentials(machineId, connectionId, newCredentials, env) {
  const data = await getMachineData(machineId, env);
  const conn = data?.providers?.[connectionId];
  if (!conn || !newCredentials) return;

  if (newCredentials.accessToken) conn.accessToken = newCredentials.accessToken;
  // handleEmbeddingsCore treats a returned apiKey as a successful refresh as well.
  if (newCredentials.apiKey) conn.apiKey = newCredentials.apiKey;
  if (newCredentials.refreshToken) conn.refreshToken = newCredentials.refreshToken;
  if (newCredentials.expiresIn) {
    conn.expiresAt = new Date(Date.now() + newCredentials.expiresIn * 1000).toISOString();
    conn.expiresIn = newCredentials.expiresIn;
  }
  conn.updatedAt = new Date().toISOString();

  await saveMachineData(machineId, data, env);
  log.debug("EMBEDDINGS_TOKEN", `credentials updated | ${connectionId}`);
}
/**
 * Build the embeddings request body for the target provider.
 * Most OpenAI-compatible providers accept the same format.
 *
 * @param {string} model - Provider-side model id
 * @param {string|Array} input - Text (or array of inputs) to embed
 * @param {string} [encodingFormat] - Added as `encoding_format` when truthy
 * @param {object} [extraParams] - Optional pass-through fields such as
 *   `dimensions` or `user`. Null/undefined values are skipped and the keys
 *   `model`, `input` and `encoding_format` can never be overridden.
 * @returns {object} JSON-serializable request body
 */
function buildEmbeddingsBody(model, input, encodingFormat, extraParams = {}) {
  const body = { model, input };
  if (encodingFormat) {
    body.encoding_format = encodingFormat;
  }
  for (const [key, value] of Object.entries(extraParams ?? {})) {
    if (value === undefined || value === null) continue;
    if (key === "encoding_format" || Object.hasOwn(body, key)) continue;
    body[key] = value;
  }
  return body;
}

/**
 * Build the URL for the embeddings endpoint based on the provider.
 *
 * @param {string} provider - "openai", "openrouter" or "openai-compatible-*"
 * @param {object} [credentials] - For openai-compatible providers,
 *   `providerSpecificData.baseUrl` is used when it is a non-blank string;
 *   otherwise the OpenAI base URL is used
 * @returns {string|null} Endpoint URL, or null when the provider has no
 *   embeddings route here (the caller reports that as an error)
 */
function buildEmbeddingsUrl(provider, credentials) {
  switch (provider) {
    case "openai":
      return "https://api.openai.com/v1/embeddings";
    case "openrouter":
      return "https://openrouter.ai/api/v1/embeddings";
    default: {
      if (typeof provider !== "string" || !provider.startsWith("openai-compatible-")) {
        return null;
      }
      // openai-compatible providers: use their baseUrl + /embeddings
      const configured = credentials?.providerSpecificData?.baseUrl;
      const baseUrl =
        typeof configured === "string" && configured.trim() !== ""
          ? configured.trim()
          : "https://api.openai.com/v1";
      // Strip every trailing slash so "…/v1//" cannot produce "…/v1//embeddings".
      return `${baseUrl.replace(/\/+$/, "")}/embeddings`;
    }
  }
}
/**
 * Build headers for the embeddings request.
 *
 * Every provider handled here authenticates with a Bearer token taken from
 * `apiKey`, falling back to `accessToken`. The Authorization header is left
 * out when neither exists, instead of sending the literal "Bearer undefined".
 * OpenRouter additionally receives its attribution headers.
 *
 * @param {string} provider
 * @param {object} [credentials] - May carry `apiKey` and/or `accessToken`
 * @returns {object} Header map for fetch
 */
function buildEmbeddingsHeaders(provider, credentials) {
  const headers = { "Content-Type": "application/json" };

  const token = credentials?.apiKey || credentials?.accessToken;
  if (token) {
    headers["Authorization"] = `Bearer ${token}`;
  }

  if (provider === "openrouter") {
    headers["HTTP-Referer"] = "https://endpoint-proxy.local";
    headers["X-Title"] = "Endpoint Proxy";
  }

  return headers;
}

/**
 * Normalize the embeddings response to OpenAI format.
 * Most OpenAI-compatible providers already return this format.
 *
 * - A body that already is `{ object: "list", data: [...] }` is returned as-is
 *   (same reference).
 * - A body with a `data` array but a missing or different `object` is copied
 *   with `object: "list"`, and `model` filled in when the provider omitted it.
 * - Anything else (including null and non-objects) is returned unchanged.
 *
 * @param {*} responseBody - Parsed JSON from the provider
 * @param {string} [model] - Model id used as fallback for a missing `model`
 * @returns {*} Normalized body
 */
function normalizeEmbeddingsResponse(responseBody, model) {
  if (responseBody === null || typeof responseBody !== "object") {
    return responseBody;
  }

  if (!Array.isArray(responseBody.data)) {
    return responseBody;
  }

  // Already in OpenAI format
  if (responseBody.object === "list") {
    return responseBody;
  }

  return { ...responseBody, object: "list", model: responseBody.model ?? model };
}
/**
 * Ask the provider's executor for fresh credentials.
 * Any failure (no executor, no refresh support, thrown error) is logged and
 * reported as null so the caller can continue with the original response.
 *
 * @param {string} provider
 * @param {object} credentials
 * @param {object} [log]
 * @returns {Promise<object|null>}
 */
async function refreshEmbeddingsCredentials(provider, credentials, log) {
  try {
    const executor = getExecutor(provider);
    if (typeof executor?.refreshCredentials !== "function") return null;
    return await refreshWithRetry(
      () => executor.refreshCredentials(credentials, log),
      3,
      log
    );
  } catch (error) {
    log?.warn?.("TOKEN", `${provider.toUpperCase()} | refresh threw: ${error?.message ?? error}`);
    return null;
  }
}

/**
 * Run an optional bookkeeping callback without letting its failure change
 * the outcome of the embeddings request itself.
 *
 * @param {function|undefined} callback
 * @param {string} label - Used in the warning log line
 * @param {object} [log]
 * @param {...*} args - Forwarded to the callback
 * @returns {Promise<void>}
 */
async function runEmbeddingsHook(callback, label, log, ...args) {
  if (!callback) return;
  try {
    await callback(...args);
  } catch (error) {
    log?.warn?.("EMBEDDINGS", `${label} callback failed: ${error?.message ?? error}`);
  }
}

/**
 * Core embeddings handler — shared between Worker and SSE server.
 *
 * Validates the input, sends the request to the provider, retries once with
 * refreshed credentials on 401/403, and returns either a success result with
 * a JSON Response or an error result built by createErrorResult.
 *
 * @param {object} options
 * @param {object} options.body - Parsed request body { model, input, encoding_format, dimensions?, user? }
 * @param {object} options.modelInfo - { provider, model }
 * @param {object} options.credentials - Provider credentials (updated in place after a refresh)
 * @param {object} [options.log] - Logger
 * @param {function} [options.onCredentialsRefreshed] - Called when creds are refreshed
 * @param {function} [options.onRequestSuccess] - Called on success (clear error state)
 * @returns {Promise<{ success: boolean, response: Response, status?: number, error?: string }>}
 */
export async function handleEmbeddingsCore({
  body,
  modelInfo,
  credentials,
  log,
  onCredentialsRefreshed,
  onRequestSuccess
}) {
  const { provider, model } = modelInfo;

  // Validate input
  const input = body.input;
  if (!input) {
    return createErrorResult(HTTP_STATUS.BAD_REQUEST, "Missing required field: input");
  }
  if (typeof input !== "string" && !Array.isArray(input)) {
    return createErrorResult(HTTP_STATUS.BAD_REQUEST, "input must be a string or array of strings");
  }

  const encodingFormat = body.encoding_format || "float";

  // Determine embeddings URL
  const url = buildEmbeddingsUrl(provider, credentials);
  if (!url) {
    return createErrorResult(
      HTTP_STATUS.BAD_REQUEST,
      `Provider '${provider}' does not support embeddings. Use openai, openrouter, or an openai-compatible provider.`
    );
  }

  const headers = buildEmbeddingsHeaders(provider, credentials);
  const requestBody = buildEmbeddingsBody(model, input, encodingFormat);
  // Forward the optional OpenAI embeddings parameters; dropping them would
  // silently change the result (e.g. full-size vectors despite `dimensions`).
  for (const key of ["dimensions", "user"]) {
    if (body[key] !== undefined && body[key] !== null) {
      requestBody[key] = body[key];
    }
  }
  const payload = JSON.stringify(requestBody);

  log?.debug?.("EMBEDDINGS", `${provider.toUpperCase()} | ${model} | input_type=${Array.isArray(input) ? `array[${input.length}]` : "string"}`);

  let providerResponse;
  try {
    providerResponse = await fetch(url, { method: "POST", headers, body: payload });
  } catch (error) {
    const errMsg = formatProviderError(error, provider, model, HTTP_STATUS.BAD_GATEWAY);
    log?.debug?.("EMBEDDINGS", `Fetch error: ${errMsg}`);
    return createErrorResult(HTTP_STATUS.BAD_GATEWAY, errMsg);
  }

  // Handle 401/403 — try token refresh
  if (
    providerResponse.status === HTTP_STATUS.UNAUTHORIZED ||
    providerResponse.status === HTTP_STATUS.FORBIDDEN
  ) {
    const newCredentials = await refreshEmbeddingsCredentials(provider, credentials, log);

    if (newCredentials?.accessToken || newCredentials?.apiKey) {
      log?.info?.("TOKEN", `${provider.toUpperCase()} | refreshed for embeddings`);
      Object.assign(credentials, newCredentials);
      await runEmbeddingsHook(onCredentialsRefreshed, "onCredentialsRefreshed", log, newCredentials);

      // Retry with refreshed credentials. If the retry itself cannot be sent,
      // the original 401/403 response is kept and reported below.
      try {
        providerResponse = await fetch(url, {
          method: "POST",
          headers: buildEmbeddingsHeaders(provider, credentials),
          body: payload
        });
      } catch (retryError) {
        log?.warn?.("TOKEN", `${provider.toUpperCase()} | retry after refresh failed: ${retryError?.message ?? retryError}`);
      }
    } else {
      log?.warn?.("TOKEN", `${provider.toUpperCase()} | refresh failed`);
    }
  }

  if (!providerResponse.ok) {
    const { statusCode, message } = await parseUpstreamError(providerResponse, provider);
    const errMsg = formatProviderError(new Error(message), provider, model, statusCode);
    log?.debug?.("EMBEDDINGS", `Provider error: ${errMsg}`);
    return createErrorResult(statusCode, errMsg);
  }

  let responseBody;
  try {
    responseBody = await providerResponse.json();
  } catch (parseError) {
    return createErrorResult(HTTP_STATUS.BAD_GATEWAY, `Invalid JSON response from ${provider}`);
  }
  // `null` and primitives parse as JSON but are not an embeddings payload.
  if (responseBody === null || typeof responseBody !== "object") {
    return createErrorResult(HTTP_STATUS.BAD_GATEWAY, `Invalid JSON response from ${provider}`);
  }

  // Bookkeeping must not turn a successful upstream call into a failure.
  await runEmbeddingsHook(onRequestSuccess, "onRequestSuccess", log);

  const normalized = normalizeEmbeddingsResponse(responseBody, model);

  log?.debug?.("EMBEDDINGS", `Success | usage=${JSON.stringify(normalized?.usage || {})}`);

  return {
    success: true,
    response: new Response(JSON.stringify(normalized), {
      headers: {
        "Content-Type": "application/json",
        "Access-Control-Allow-Origin": "*"
      }
    })
  };
}

// ─── src/app/api/v1/embeddings/route.js ──────────────────────────────────────

/**
 * Handle CORS preflight
 *
 * @returns {Promise<Response>} Empty response carrying the CORS headers
 */
export async function OPTIONS() {
  return new Response(null, {
    headers: {
      "Access-Control-Allow-Origin": "*",
      "Access-Control-Allow-Methods": "POST, OPTIONS",
      "Access-Control-Allow-Headers": "*"
    }
  });
}

/**
 * POST /v1/embeddings - OpenAI-compatible embeddings endpoint
 *
 * @param {Request} request
 * @returns {Promise<Response>} Whatever the SSE handleEmbeddings handler returns
 */
export async function POST(request) {
  return await handleEmbeddings(request);
}
/**
 * Handle embeddings request for the SSE/Next.js server.
 * Follows the same auth + fallback pattern the original notes for handleChat:
 * parse body → optional API key check → resolve model → try provider
 * accounts one after another until one succeeds or none is left.
 *
 * @param {Request} request
 * @returns {Promise<Response>}
 */
export async function handleEmbeddings(request) {
  let body;
  try {
    body = await request.json();
  } catch {
    log.warn("EMBEDDINGS", "Invalid JSON body");
    return errorResponse(HTTP_STATUS.BAD_REQUEST, "Invalid JSON body");
  }
  // `null` or a primitive is valid JSON but `body.model` would throw on it.
  if (body === null || typeof body !== "object") {
    log.warn("EMBEDDINGS", "Invalid JSON body");
    return errorResponse(HTTP_STATUS.BAD_REQUEST, "Invalid JSON body");
  }

  const url = new URL(request.url);
  const modelStr = body.model;

  log.request("POST", `${url.pathname} | ${modelStr}`);

  // Log API key (masked)
  const apiKey = extractApiKey(request);
  if (apiKey) {
    log.debug("AUTH", `API Key: ${log.maskKey(apiKey)}`);
  } else {
    log.debug("AUTH", "No API key provided (local mode)");
  }

  // Optional strict API key validation
  if (process.env.REQUIRE_API_KEY === "true") {
    if (!apiKey) {
      log.warn("AUTH", "Missing API key while REQUIRE_API_KEY=true");
      return errorResponse(HTTP_STATUS.UNAUTHORIZED, "Missing API key");
    }
    const valid = await isValidApiKey(apiKey);
    if (!valid) {
      log.warn("AUTH", "Invalid API key while REQUIRE_API_KEY=true");
      return errorResponse(HTTP_STATUS.UNAUTHORIZED, "Invalid API key");
    }
  }

  if (!modelStr) {
    log.warn("EMBEDDINGS", "Missing model");
    return errorResponse(HTTP_STATUS.BAD_REQUEST, "Missing model");
  }

  if (!body.input) {
    log.warn("EMBEDDINGS", "Missing input");
    return errorResponse(HTTP_STATUS.BAD_REQUEST, "Missing required field: input");
  }

  const modelInfo = await getModelInfo(modelStr);
  if (!modelInfo.provider) {
    log.warn("EMBEDDINGS", "Invalid model format", { model: modelStr });
    return errorResponse(HTTP_STATUS.BAD_REQUEST, "Invalid model format");
  }

  const { provider, model } = modelInfo;

  if (modelStr !== `${provider}/${model}`) {
    log.info("ROUTING", `${modelStr} → ${provider}/${model}`);
  } else {
    log.info("ROUTING", `Provider: ${provider}, Model: ${model}`);
  }

  // Credential + fallback loop
  let excludeConnectionId = null;
  let lastError = null;
  let lastStatus = null;
  // Only one connection id can be handed to getProviderCredentials as an
  // exclusion, so an account that failed earlier in this request could be
  // offered again. Remember every attempt and stop at the first repeat, so
  // the loop always terminates.
  const triedConnectionIds = new Set();

  while (true) {
    const credentials = await getProviderCredentials(provider, excludeConnectionId, model);

    // All accounts unavailable
    if (!credentials || credentials.allRateLimited) {
      if (credentials?.allRateLimited) {
        const errorMsg = lastError || credentials.lastError || "Unavailable";
        const status = lastStatus || Number(credentials.lastErrorCode) || HTTP_STATUS.SERVICE_UNAVAILABLE;
        log.warn("EMBEDDINGS", `[${provider}/${model}] ${errorMsg} (${credentials.retryAfterHuman})`);
        return unavailableResponse(status, `[${provider}/${model}] ${errorMsg}`, credentials.retryAfter, credentials.retryAfterHuman);
      }
      if (!excludeConnectionId) {
        log.error("AUTH", `No credentials for provider: ${provider}`);
        return errorResponse(HTTP_STATUS.BAD_REQUEST, `No credentials for provider: ${provider}`);
      }
      log.warn("EMBEDDINGS", "No more accounts available", { provider });
      return errorResponse(lastStatus || HTTP_STATUS.SERVICE_UNAVAILABLE, lastError || "All accounts unavailable");
    }

    if (triedConnectionIds.has(credentials.connectionId)) {
      log.warn("EMBEDDINGS", "No more accounts available", { provider });
      return errorResponse(lastStatus || HTTP_STATUS.SERVICE_UNAVAILABLE, lastError || "All accounts unavailable");
    }
    triedConnectionIds.add(credentials.connectionId);

    const accountId = credentials.connectionId.slice(0, 8);
    log.info("AUTH", `Using ${provider} account: ${accountId}...`);

    const refreshedCredentials = await checkAndRefreshToken(provider, credentials);

    const result = await handleEmbeddingsCore({
      body: { ...body, model: `${provider}/${model}` },
      modelInfo: { provider, model },
      credentials: refreshedCredentials,
      log,
      onCredentialsRefreshed: async (newCreds) => {
        // Forward only the fields the refresh actually produced; passing
        // `undefined` could wipe a stored value if the persistence layer
        // merges the object as-is (NOTE(review): confirm against
        // updateProviderCredentials).
        const update = {};
        for (const field of ["accessToken", "refreshToken", "providerSpecificData"]) {
          if (newCreds?.[field] !== undefined) update[field] = newCreds[field];
        }
        update.testStatus = "active";
        await updateProviderCredentials(credentials.connectionId, update);
      },
      onRequestSuccess: async () => {
        await clearAccountError(credentials.connectionId, credentials);
      }
    });

    if (result.success) return result.response;

    const { shouldFallback } = await markAccountUnavailable(credentials.connectionId, result.status, result.error, provider, model);

    if (shouldFallback) {
      log.warn("AUTH", `Account ${accountId}... unavailable (${result.status}), trying fallback`);
      excludeConnectionId = credentials.connectionId;
      lastError = result.error;
      lastStatus = result.status;
      continue;
    }

    return result.response;
  }
}
`handleEmbeddingsCore` token refresh: 401 retry, graceful fallback + +### `embeddings.cloud.test.js` (23 tests) +- CORS OPTIONS: 200 response, empty body, correct headers +- Authentication: missing key, bad format, old-format key, wrong key value, valid key +- Body validation: invalid JSON, missing model, missing input, bad model +- Happy path: single string, array, correct delegation, CORS header, machineId override +- Rate limiting: all accounts rate-limited → 503 + Retry-After, no credentials → 400 +- Error propagation: non-fallback errors passed through, 429 exhausts accounts +- machineId override: validates key, rejects wrong key diff --git a/tests/package.json b/tests/package.json new file mode 100644 index 00000000..0237ab1e --- /dev/null +++ b/tests/package.json @@ -0,0 +1,17 @@ +{ + "name": "9router-tests", + "version": "1.0.0", + "private": true, + "type": "module", + "description": "Unit tests for 9router embeddings endpoint", + "scripts": { + "test": "NODE_PATH=/tmp/node_modules /tmp/node_modules/.bin/vitest run --reporter=verbose", + "test:watch": "NODE_PATH=/tmp/node_modules /tmp/node_modules/.bin/vitest --reporter=verbose" + }, + "devDependencies": { + "vitest": "^4.0.0" + }, + "engines": { + "node": ">=18" + } +} diff --git a/tests/unit/embeddings.cloud.test.js b/tests/unit/embeddings.cloud.test.js new file mode 100644 index 00000000..725f7dd7 --- /dev/null +++ b/tests/unit/embeddings.cloud.test.js @@ -0,0 +1,524 @@ +/** + * Unit tests for cloud/src/handlers/embeddings.js + * + * Tests cover: + * - CORS OPTIONS → 200 with CORS headers + * - Auth: missing Bearer → 401 + * - Auth: invalid key format → 401 + * - Auth: valid new-format key but wrong key value → 401 + * - Body validation: missing model → 400, missing input → 400 + * - Invalid model format → 400 + * - Happy path → delegates to handleEmbeddingsCore and returns response + * - Rate-limited provider → 503 with Retry-After + * - No credentials → 400 + * + * Strategy: mock all external 
dependencies (D1 storage, handleEmbeddingsCore, apiKey utils) + * so tests run without Cloudflare Workers runtime. + */ + +import { describe, it, expect, vi, beforeEach, afterEach } from "vitest"; + +// ─── Module mocks (hoisted before imports) ─────────────────────────────────── + +vi.mock("../../open-sse/services/model.js", () => ({ + getModelInfoCore: vi.fn(), +})); + +vi.mock("../../open-sse/handlers/embeddingsCore.js", () => ({ + handleEmbeddingsCore: vi.fn(), +})); + +vi.mock("../../open-sse/utils/error.js", async (importOriginal) => { + // Use real errorResponse implementation so response bodies are realistic + const actual = await importOriginal(); + return actual; +}); + +vi.mock("../../open-sse/services/accountFallback.js", async (importOriginal) => { + const actual = await importOriginal(); + return actual; +}); + +vi.mock("../../cloud/src/utils/logger.js", () => ({ + info: vi.fn(), + debug: vi.fn(), + warn: vi.fn(), + error: vi.fn(), +})); + +vi.mock("../../cloud/src/utils/apiKey.js", () => ({ + parseApiKey: vi.fn(), + extractBearerToken: vi.fn(), +})); + +vi.mock("../../cloud/src/services/storage.js", () => ({ + getMachineData: vi.fn(), + saveMachineData: vi.fn(), +})); + +// ─── Imports (after mocks) ──────────────────────────────────────────────────── + +import { handleEmbeddings } from "../../cloud/src/handlers/embeddings.js"; +import { getModelInfoCore } from "../../open-sse/services/model.js"; +import { handleEmbeddingsCore } from "../../open-sse/handlers/embeddingsCore.js"; +import { parseApiKey, extractBearerToken } from "../../cloud/src/utils/apiKey.js"; +import { getMachineData, saveMachineData } from "../../cloud/src/services/storage.js"; + +// ─── Fixtures ───────────────────────────────────────────────────────────────── + +const MACHINE_ID = "mach01"; +const VALID_API_KEY = "sk-mach01-key01-ab12cd34"; // new format shape +const VALID_EMBEDDING_RESPONSE_BODY = { + object: "list", + data: [{ object: "embedding", index: 0, embedding: [0.1, 
0.2, 0.3] }], + model: "text-embedding-ada-002", + usage: { prompt_tokens: 3, total_tokens: 3 }, +}; + +/** Build a minimal mock env (Cloudflare Worker env bindings) */ +function makeEnv() { + return { DB: {}, KV: {} }; +} + +/** Build a mock machine data record stored in D1 */ +function makeMachineData(overrides = {}) { + return { + machineId: MACHINE_ID, + apiKeys: [{ key: VALID_API_KEY, label: "test" }], + providers: { + "conn-001": { + provider: "openai", + apiKey: "sk-openai-provider-key", + isActive: true, + priority: 1, + status: "active", + rateLimitedUntil: null, + lastError: null, + }, + }, + modelAliases: {}, + ...overrides, + }; +} + +/** Make a Request object */ +function makeRequest(method = "POST", body = null, authHeader = `Bearer ${VALID_API_KEY}`) { + const headers = { "Content-Type": "application/json" }; + if (authHeader) headers["Authorization"] = authHeader; + + return new Request("https://9cli.hxd.app/v1/embeddings", { + method, + headers, + body: body ? JSON.stringify(body) : undefined, + }); +} + +// ─── Tests: CORS OPTIONS ────────────────────────────────────────────────────── + +describe("handleEmbeddings — CORS OPTIONS", () => { + it("OPTIONS request → 200 with Access-Control-Allow-Origin: *", async () => { + const req = makeRequest("OPTIONS", null, null); + const res = await handleEmbeddings(req, makeEnv(), {}); + + expect(res.status).toBe(200); + expect(res.headers.get("Access-Control-Allow-Origin")).toBe("*"); + expect(res.headers.get("Access-Control-Allow-Methods")).toMatch(/POST/); + }); + + it("OPTIONS request → body is empty/null", async () => { + const req = makeRequest("OPTIONS", null, null); + const res = await handleEmbeddings(req, makeEnv(), {}); + const text = await res.text(); + expect(text).toBe(""); + }); +}); + +// ─── Tests: Authentication ──────────────────────────────────────────────────── + +describe("handleEmbeddings — authentication", () => { + beforeEach(() => { + 
vi.mocked(extractBearerToken).mockReturnValue(null); + vi.mocked(parseApiKey).mockResolvedValue(null); + vi.mocked(getMachineData).mockResolvedValue(makeMachineData()); + vi.mocked(getModelInfoCore).mockResolvedValue({ provider: "openai", model: "text-embedding-ada-002" }); + }); + + afterEach(() => { + vi.clearAllMocks(); + }); + + it("missing Authorization header → 401", async () => { + vi.mocked(extractBearerToken).mockReturnValue(null); + + const req = makeRequest("POST", { model: "ag/gemini-embedding-001", input: "hello" }, null); + const res = await handleEmbeddings(req, makeEnv(), {}); + + expect(res.status).toBe(401); + const body = await res.json(); + expect(body.error.message).toMatch(/missing api key/i); + }); + + it("Authorization header without Bearer scheme → 401", async () => { + vi.mocked(extractBearerToken).mockReturnValue(null); + + const req = makeRequest("POST", { model: "ag/gemini-embedding-001", input: "hello" }, "Token abc123"); + const res = await handleEmbeddings(req, makeEnv(), {}); + + expect(res.status).toBe(401); + }); + + it("Bearer key that fails parseApiKey → 401", async () => { + vi.mocked(extractBearerToken).mockReturnValue("sk-invalidkey"); + vi.mocked(parseApiKey).mockResolvedValue(null); + + const req = makeRequest("POST", { model: "ag/gemini-embedding-001", input: "hello" }); + const res = await handleEmbeddings(req, makeEnv(), {}); + + expect(res.status).toBe(401); + const body = await res.json(); + expect(body.error.message).toMatch(/invalid api key format/i); + }); + + it("old-format key (no machineId) → 400 asking to use machineId endpoint", async () => { + vi.mocked(extractBearerToken).mockReturnValue("sk-oldfmt8"); + vi.mocked(parseApiKey).mockResolvedValue({ machineId: null, keyId: "oldfmt8", isNewFormat: false }); + + const req = makeRequest("POST", { model: "ag/gemini-embedding-001", input: "hello" }); + const res = await handleEmbeddings(req, makeEnv(), {}); + + expect(res.status).toBe(400); + const body = await 
res.json(); + expect(body.error.message).toMatch(/machineId/i); + }); + + it("valid key format but key value not in machine apiKeys → 401", async () => { + vi.mocked(extractBearerToken).mockReturnValue("sk-mach01-key01-ab12cd34"); + vi.mocked(parseApiKey).mockResolvedValue({ machineId: MACHINE_ID, keyId: "key01", isNewFormat: true }); + vi.mocked(getMachineData).mockResolvedValue(makeMachineData({ + apiKeys: [{ key: "sk-different-key" }], // key doesn't match + })); + + const req = makeRequest("POST", { model: "ag/gemini-embedding-001", input: "hello" }); + const res = await handleEmbeddings(req, makeEnv(), {}); + + expect(res.status).toBe(401); + const body = await res.json(); + expect(body.error.message).toMatch(/invalid api key/i); + }); + + it("valid key → passes auth (proceeds to body parsing)", async () => { + vi.mocked(extractBearerToken).mockReturnValue(VALID_API_KEY); + vi.mocked(parseApiKey).mockResolvedValue({ machineId: MACHINE_ID, keyId: "key01", isNewFormat: true }); + vi.mocked(getMachineData).mockResolvedValue(makeMachineData()); + vi.mocked(getModelInfoCore).mockResolvedValue({ provider: "openai", model: "text-embedding-ada-002" }); + vi.mocked(handleEmbeddingsCore).mockResolvedValue({ + success: true, + response: new Response(JSON.stringify(VALID_EMBEDDING_RESPONSE_BODY), { + status: 200, + headers: { "Content-Type": "application/json", "Access-Control-Allow-Origin": "*" }, + }), + }); + + const req = makeRequest("POST", { model: "openai/text-embedding-ada-002", input: "hello" }); + const res = await handleEmbeddings(req, makeEnv(), {}); + + // Should not be 401 + expect(res.status).not.toBe(401); + expect(res.status).not.toBe(403); + }); +}); + +// ─── Tests: Body validation ─────────────────────────────────────────────────── + +describe("handleEmbeddings — body validation", () => { + beforeEach(() => { + vi.mocked(extractBearerToken).mockReturnValue(VALID_API_KEY); + vi.mocked(parseApiKey).mockResolvedValue({ machineId: MACHINE_ID, keyId: 
"key01", isNewFormat: true }); + vi.mocked(getMachineData).mockResolvedValue(makeMachineData()); + }); + + afterEach(() => { + vi.clearAllMocks(); + }); + + it("invalid JSON body → 400", async () => { + const req = new Request("https://9cli.hxd.app/v1/embeddings", { + method: "POST", + headers: { + "Content-Type": "application/json", + "Authorization": `Bearer ${VALID_API_KEY}`, + }, + body: "{ bad json", + }); + const res = await handleEmbeddings(req, makeEnv(), {}); + + expect(res.status).toBe(400); + const body = await res.json(); + expect(body.error.message).toMatch(/invalid json/i); + }); + + it("missing model field → 400", async () => { + const req = makeRequest("POST", { input: "hello world" }); + const res = await handleEmbeddings(req, makeEnv(), {}); + + expect(res.status).toBe(400); + const body = await res.json(); + expect(body.error.message).toMatch(/missing model/i); + }); + + it("missing input field → 400", async () => { + const req = makeRequest("POST", { model: "ag/gemini-embedding-001" }); + const res = await handleEmbeddings(req, makeEnv(), {}); + + expect(res.status).toBe(400); + const body = await res.json(); + expect(body.error.message).toMatch(/missing required field: input/i); + }); + + it("model with no provider mapping → 400", async () => { + vi.mocked(getModelInfoCore).mockResolvedValue({ provider: null, model: null }); + + const req = makeRequest("POST", { model: "nonexistent/model", input: "hello" }); + const res = await handleEmbeddings(req, makeEnv(), {}); + + expect(res.status).toBe(400); + const body = await res.json(); + expect(body.error.message).toMatch(/invalid model format/i); + }); +}); + +// ─── Tests: Happy path — valid request ──────────────────────────────────────── + +describe("handleEmbeddings — valid request (happy path)", () => { + beforeEach(() => { + vi.mocked(extractBearerToken).mockReturnValue(VALID_API_KEY); + vi.mocked(parseApiKey).mockResolvedValue({ machineId: MACHINE_ID, keyId: "key01", isNewFormat: true }); + 
vi.mocked(getMachineData).mockResolvedValue(makeMachineData()); + vi.mocked(getModelInfoCore).mockResolvedValue({ provider: "openai", model: "text-embedding-ada-002" }); + vi.mocked(handleEmbeddingsCore).mockResolvedValue({ + success: true, + response: new Response(JSON.stringify(VALID_EMBEDDING_RESPONSE_BODY), { + status: 200, + headers: { "Content-Type": "application/json", "Access-Control-Allow-Origin": "*" }, + }), + }); + vi.mocked(saveMachineData).mockResolvedValue(undefined); + }); + + afterEach(() => { + vi.clearAllMocks(); + }); + + it("single string input → 200 with embeddings data", async () => { + const req = makeRequest("POST", { + model: "openai/text-embedding-ada-002", + input: "Hello world test embedding", + }); + const res = await handleEmbeddings(req, makeEnv(), {}); + + expect(res.status).toBe(200); + const body = await res.json(); + expect(body.object).toBe("list"); + expect(Array.isArray(body.data)).toBe(true); + }); + + it("array input → 200 with embeddings data", async () => { + const req = makeRequest("POST", { + model: "openai/text-embedding-ada-002", + input: ["Hello", "World"], + }); + const res = await handleEmbeddings(req, makeEnv(), {}); + + expect(res.status).toBe(200); + const body = await res.json(); + expect(body.object).toBe("list"); + }); + + it("delegates to handleEmbeddingsCore with correct args", async () => { + const req = makeRequest("POST", { + model: "openai/text-embedding-ada-002", + input: "Test", + }); + await handleEmbeddings(req, makeEnv(), {}); + + expect(handleEmbeddingsCore).toHaveBeenCalledOnce(); + const callArgs = vi.mocked(handleEmbeddingsCore).mock.calls[0][0]; + expect(callArgs.body.input).toBe("Test"); + expect(callArgs.modelInfo.provider).toBe("openai"); + expect(callArgs.modelInfo.model).toBe("text-embedding-ada-002"); + expect(callArgs.credentials).toBeDefined(); + }); + + it("response has CORS header from addCorsHeaders wrapper", async () => { + const req = makeRequest("POST", { + model: 
"openai/text-embedding-ada-002", + input: "Hello", + }); + const res = await handleEmbeddings(req, makeEnv(), {}); + + expect(res.headers.get("Access-Control-Allow-Origin")).toBe("*"); + }); + + it("machineId-override path: /{machineId}/v1/embeddings works", async () => { + // Direct call with machineId override (old format URL path) + const req = new Request(`https://9cli.hxd.app/${MACHINE_ID}/v1/embeddings`, { + method: "POST", + headers: { + "Content-Type": "application/json", + "Authorization": `Bearer ${VALID_API_KEY}`, + }, + body: JSON.stringify({ model: "openai/text-embedding-ada-002", input: "Hello" }), + }); + + const res = await handleEmbeddings(req, makeEnv(), {}, MACHINE_ID); + expect(res.status).toBe(200); + }); +}); + +// ─── Tests: Rate limiting ────────────────────────────────────────────────────── + +describe("handleEmbeddings — rate limit fallback", () => { + beforeEach(() => { + vi.mocked(extractBearerToken).mockReturnValue(VALID_API_KEY); + vi.mocked(parseApiKey).mockResolvedValue({ machineId: MACHINE_ID, keyId: "key01", isNewFormat: true }); + vi.mocked(getModelInfoCore).mockResolvedValue({ provider: "openai", model: "text-embedding-ada-002" }); + vi.mocked(saveMachineData).mockResolvedValue(undefined); + }); + + afterEach(() => { + vi.clearAllMocks(); + }); + + it("all provider accounts rate-limited → 503 with Retry-After header", async () => { + const rateLimitedUntil = new Date(Date.now() + 60000).toISOString(); // 60s from now + vi.mocked(getMachineData).mockResolvedValue(makeMachineData({ + providers: { + "conn-001": { + provider: "openai", + apiKey: "sk-key", + isActive: true, + priority: 1, + status: "unavailable", + rateLimitedUntil, // rate-limited + lastError: "Rate limit exceeded", + errorCode: 429, + backoffLevel: 1, + }, + }, + })); + + const req = makeRequest("POST", { model: "openai/text-embedding-ada-002", input: "hello" }); + const res = await handleEmbeddings(req, makeEnv(), {}); + + expect(res.status).toBe(429); + 
expect(res.headers.get("Retry-After")).not.toBeNull(); // Headers.get() yields null (never undefined) for an absent header + const retryAfter = Number.parseInt(res.headers.get("Retry-After"), 10); + expect(retryAfter).toBeGreaterThan(0); + }); + + it("provider account not found → 400 No credentials", async () => { + vi.mocked(getMachineData).mockResolvedValue(makeMachineData({ + providers: {}, // no providers + })); + + const req = makeRequest("POST", { model: "openai/text-embedding-ada-002", input: "hello" }); + const res = await handleEmbeddings(req, makeEnv(), {}); + + expect(res.status).toBe(400); + const body = await res.json(); + expect(body.error.message).toMatch(/no credentials/i); + }); + + it("core returns non-fallback error → propagates error response directly", async () => { + vi.mocked(getMachineData).mockResolvedValue(makeMachineData()); + vi.mocked(handleEmbeddingsCore).mockResolvedValue({ + success: false, + status: 400, + error: "input must be a string or array", + response: new Response( + JSON.stringify({ error: { message: "input must be a string or array" } }), + { status: 400, headers: { "Content-Type": "application/json" } } + ), + }); + + const req = makeRequest("POST", { model: "openai/text-embedding-ada-002", input: "hello" }); + const res = await handleEmbeddings(req, makeEnv(), {}); + + // Non-fallback error (400) should not trigger account cycle; returns error directly + expect(res.status).toBe(400); + }); + + it("core returns 429 → marks account unavailable, then no more accounts → 503", async () => { + vi.mocked(getMachineData).mockResolvedValue(makeMachineData()); + vi.mocked(handleEmbeddingsCore).mockResolvedValue({ + success: false, + status: 429, + error: "Rate limit exceeded", + response: new Response( + JSON.stringify({ error: { message: "Rate limit exceeded" } }), + { status: 429, headers: { "Content-Type": "application/json" } } + ), + }); + + const req = makeRequest("POST", { model: "openai/text-embedding-ada-002", input: "hello" }); + const res = await handleEmbeddings(req, makeEnv(), {}); + + // After 
fallback loop exhausts accounts + expect([429, 503]).toContain(res.status); + }); +}); + +// ─── Tests: machineId-override (old-format URL path) ───────────────────────── + +describe("handleEmbeddings — machineId override path", () => { + beforeEach(() => { + // When machineId is provided via URL, no apiKey parsing needed for machineId + vi.mocked(getMachineData).mockResolvedValue(makeMachineData()); + vi.mocked(getModelInfoCore).mockResolvedValue({ provider: "openai", model: "text-embedding-ada-002" }); + vi.mocked(handleEmbeddingsCore).mockResolvedValue({ + success: true, + response: new Response(JSON.stringify(VALID_EMBEDDING_RESPONSE_BODY), { + status: 200, + headers: { "Content-Type": "application/json", "Access-Control-Allow-Origin": "*" }, + }), + }); + vi.mocked(saveMachineData).mockResolvedValue(undefined); + }); + + afterEach(() => { + vi.clearAllMocks(); + }); + + it("with machineIdOverride, still validates API key via Authorization header", async () => { + // Key IS in the machine's apiKeys → should succeed + const req = new Request(`https://9cli.hxd.app/${MACHINE_ID}/v1/embeddings`, { + method: "POST", + headers: { + "Content-Type": "application/json", + "Authorization": `Bearer ${VALID_API_KEY}`, + }, + body: JSON.stringify({ model: "openai/text-embedding-ada-002", input: "test" }), + }); + + const res = await handleEmbeddings(req, makeEnv(), {}, MACHINE_ID); + expect(res.status).toBe(200); + }); + + it("with machineIdOverride, wrong API key → 401", async () => { + vi.mocked(getMachineData).mockResolvedValue(makeMachineData({ + apiKeys: [{ key: "sk-correct-key" }], + })); + + const req = new Request(`https://9cli.hxd.app/${MACHINE_ID}/v1/embeddings`, { + method: "POST", + headers: { + "Content-Type": "application/json", + "Authorization": "Bearer sk-wrong-key", + }, + body: JSON.stringify({ model: "openai/text-embedding-ada-002", input: "test" }), + }); + + const res = await handleEmbeddings(req, makeEnv(), {}, MACHINE_ID); + 
expect(res.status).toBe(401); + }); +}); diff --git a/tests/unit/embeddingsCore.test.js b/tests/unit/embeddingsCore.test.js new file mode 100644 index 00000000..cc5e0a1c --- /dev/null +++ b/tests/unit/embeddingsCore.test.js @@ -0,0 +1,586 @@ +/** + * Unit tests for open-sse/handlers/embeddingsCore.js + * + * Tests cover: + * - buildEmbeddingsBody() — request body construction + * - buildEmbeddingsUrl() — URL per provider + * - buildEmbeddingsHeaders() — headers per provider + * - handleEmbeddingsCore() — full handler: success, errors, validation + */ + +import { describe, it, expect, vi, beforeEach, afterEach } from "vitest"; + +// ─── Mock the executors/index.js to avoid transitive uuid dependency ───────── +// kiro.js (imported by executors/index.js) requires 'uuid' which isn't +// installed in the test environment. We mock the whole executor layer. +vi.mock("../../open-sse/executors/index.js", () => ({ + getExecutor: vi.fn(() => ({ + refreshCredentials: vi.fn().mockResolvedValue(null), + })), + hasSpecializedExecutor: vi.fn(() => false), +})); + +// Also mock tokenRefresh to avoid side effects +vi.mock("../../open-sse/services/tokenRefresh.js", () => ({ + refreshWithRetry: vi.fn().mockResolvedValue(null), +})); + +// Mock proxyFetch to avoid proxy-agent imports in test env +vi.mock("../../open-sse/utils/proxyFetch.js", () => ({ + default: vi.fn(), +})); + +import { handleEmbeddingsCore } from "../../open-sse/handlers/embeddingsCore.js"; + +// ─── Helpers ───────────────────────────────────────────────────────────────── + +/** Build a minimal success Response from a provider */ +function makeProviderResponse(body, status = 200) { + return new Response(JSON.stringify(body), { + status, + headers: { "Content-Type": "application/json" }, + }); +} + +/** Build a minimal error Response from a provider */ +function makeProviderErrorResponse(status, message) { + return new Response(JSON.stringify({ error: { message } }), { + status, + headers: { "Content-Type": 
"application/json" }, + }); +} + +/** Standard valid embeddings response in OpenAI format */ +const VALID_EMBEDDING_RESPONSE = { + object: "list", + data: [ + { + object: "embedding", + index: 0, + embedding: [0.1, 0.2, 0.3], + }, + ], + model: "text-embedding-ada-002", + usage: { prompt_tokens: 3, total_tokens: 3 }, +}; + +/** Standard handleEmbeddingsCore options for OpenAI provider */ +function makeOptions(overrides = {}) { + return { + body: { model: "text-embedding-ada-002", input: "Hello world" }, + modelInfo: { provider: "openai", model: "text-embedding-ada-002" }, + credentials: { apiKey: "sk-test-key" }, + log: { debug: vi.fn(), info: vi.fn(), warn: vi.fn(), error: vi.fn() }, + onCredentialsRefreshed: vi.fn(), + onRequestSuccess: vi.fn(), + ...overrides, + }; +} + +// ─── Test: buildEmbeddingsBody (via handleEmbeddingsCore internals) ────────── +// We test body construction indirectly by verifying the fetch call payload. + +describe("buildEmbeddingsBody", () => { + beforeEach(() => { + vi.stubGlobal("fetch", vi.fn()); + }); + afterEach(() => { + vi.unstubAllGlobals(); + }); + + it("single string input — includes model and input, default encoding_format=float", async () => { + vi.mocked(fetch).mockResolvedValueOnce(makeProviderResponse(VALID_EMBEDDING_RESPONSE)); + + await handleEmbeddingsCore(makeOptions({ + body: { model: "text-embedding-ada-002", input: "Hello world" }, + })); + + const [, init] = vi.mocked(fetch).mock.calls[0]; + const sent = JSON.parse(init.body); + expect(sent.model).toBe("text-embedding-ada-002"); + expect(sent.input).toBe("Hello world"); + expect(sent.encoding_format).toBe("float"); + }); + + it("array input — passes array as-is", async () => { + vi.mocked(fetch).mockResolvedValueOnce(makeProviderResponse(VALID_EMBEDDING_RESPONSE)); + + await handleEmbeddingsCore(makeOptions({ + body: { model: "text-embedding-ada-002", input: ["Hello", "World"] }, + })); + + const [, init] = vi.mocked(fetch).mock.calls[0]; + const sent = 
JSON.parse(init.body); + expect(sent.input).toEqual(["Hello", "World"]); + }); + + it("custom encoding_format is forwarded", async () => { + vi.mocked(fetch).mockResolvedValueOnce(makeProviderResponse(VALID_EMBEDDING_RESPONSE)); + + await handleEmbeddingsCore(makeOptions({ + body: { + model: "text-embedding-ada-002", + input: "test", + encoding_format: "base64", + }, + })); + + const [, init] = vi.mocked(fetch).mock.calls[0]; + const sent = JSON.parse(init.body); + expect(sent.encoding_format).toBe("base64"); + }); + + it("no encoding_format in body → defaults to float", async () => { + vi.mocked(fetch).mockResolvedValueOnce(makeProviderResponse(VALID_EMBEDDING_RESPONSE)); + + await handleEmbeddingsCore(makeOptions({ + body: { model: "text-embedding-ada-002", input: "test" }, + })); + + const [, init] = vi.mocked(fetch).mock.calls[0]; + const sent = JSON.parse(init.body); + expect(sent.encoding_format).toBe("float"); + }); +}); + +// ─── Test: buildEmbeddingsUrl ──────────────────────────────────────────────── + +describe("buildEmbeddingsUrl", () => { + beforeEach(() => { + vi.stubGlobal("fetch", vi.fn()); + }); + afterEach(() => { + vi.unstubAllGlobals(); + }); + + it("openai → https://api.openai.com/v1/embeddings", async () => { + vi.mocked(fetch).mockResolvedValueOnce(makeProviderResponse(VALID_EMBEDDING_RESPONSE)); + + await handleEmbeddingsCore(makeOptions({ + modelInfo: { provider: "openai", model: "text-embedding-ada-002" }, + credentials: { apiKey: "sk-test" }, + })); + + const [url] = vi.mocked(fetch).mock.calls[0]; + expect(url).toBe("https://api.openai.com/v1/embeddings"); + }); + + it("openrouter → https://openrouter.ai/api/v1/embeddings", async () => { + vi.mocked(fetch).mockResolvedValueOnce(makeProviderResponse(VALID_EMBEDDING_RESPONSE)); + + await handleEmbeddingsCore(makeOptions({ + modelInfo: { provider: "openrouter", model: "openai/text-embedding-ada-002" }, + credentials: { apiKey: "sk-or-test" }, + })); + + const [url] = 
vi.mocked(fetch).mock.calls[0]; + expect(url).toBe("https://openrouter.ai/api/v1/embeddings"); + }); + + it("openai-compatible-* → uses baseUrl from providerSpecificData", async () => { + vi.mocked(fetch).mockResolvedValueOnce(makeProviderResponse(VALID_EMBEDDING_RESPONSE)); + + await handleEmbeddingsCore(makeOptions({ + modelInfo: { provider: "openai-compatible-custom", model: "embed-v1" }, + credentials: { + apiKey: "sk-custom", + providerSpecificData: { baseUrl: "https://custom.ai/v1" }, + }, + })); + + const [url] = vi.mocked(fetch).mock.calls[0]; + expect(url).toBe("https://custom.ai/v1/embeddings"); + }); + + it("openai-compatible-* strips trailing slash from baseUrl", async () => { + vi.mocked(fetch).mockResolvedValueOnce(makeProviderResponse(VALID_EMBEDDING_RESPONSE)); + + await handleEmbeddingsCore(makeOptions({ + modelInfo: { provider: "openai-compatible-myhost", model: "embed-v1" }, + credentials: { + apiKey: "sk-x", + providerSpecificData: { baseUrl: "https://myhost.ai/v1/" }, + }, + })); + + const [url] = vi.mocked(fetch).mock.calls[0]; + expect(url).toBe("https://myhost.ai/v1/embeddings"); + }); + + it("openai-compatible-* without baseUrl → falls back to api.openai.com", async () => { + vi.mocked(fetch).mockResolvedValueOnce(makeProviderResponse(VALID_EMBEDDING_RESPONSE)); + + await handleEmbeddingsCore(makeOptions({ + modelInfo: { provider: "openai-compatible-fallback", model: "embed" }, + credentials: { apiKey: "sk-x", providerSpecificData: {} }, + })); + + const [url] = vi.mocked(fetch).mock.calls[0]; + expect(url).toBe("https://api.openai.com/v1/embeddings"); + }); + + it("unsupported provider (e.g. 
gemini-cli) → 400 error, no fetch called", async () => { + const result = await handleEmbeddingsCore(makeOptions({ + modelInfo: { provider: "gemini-cli", model: "gemini-embedding" }, + credentials: { apiKey: "token" }, + })); + + expect(vi.mocked(fetch)).not.toHaveBeenCalled(); + expect(result.success).toBe(false); + expect(result.status).toBe(400); + expect(result.error).toMatch(/does not support embeddings/i); + }); + + it("antigravity (non-openai-compatible, no URL mapping) → 400", async () => { + const result = await handleEmbeddingsCore(makeOptions({ + modelInfo: { provider: "antigravity", model: "some-embed" }, + credentials: { apiKey: "ag-token" }, + })); + + expect(vi.mocked(fetch)).not.toHaveBeenCalled(); + expect(result.success).toBe(false); + expect(result.status).toBe(400); + }); +}); + +// ─── Test: buildEmbeddingsHeaders ─────────────────────────────────────────── + +describe("buildEmbeddingsHeaders", () => { + beforeEach(() => { + vi.stubGlobal("fetch", vi.fn()); + }); + afterEach(() => { + vi.unstubAllGlobals(); + }); + + it("openai → Authorization: Bearer, Content-Type: application/json", async () => { + vi.mocked(fetch).mockResolvedValueOnce(makeProviderResponse(VALID_EMBEDDING_RESPONSE)); + + await handleEmbeddingsCore(makeOptions({ + modelInfo: { provider: "openai", model: "text-embedding-ada-002" }, + credentials: { apiKey: "sk-mykey" }, + })); + + const [, init] = vi.mocked(fetch).mock.calls[0]; + expect(init.headers["Authorization"]).toBe("Bearer sk-mykey"); + expect(init.headers["Content-Type"]).toBe("application/json"); + }); + + it("openai — uses accessToken when apiKey is absent", async () => { + vi.mocked(fetch).mockResolvedValueOnce(makeProviderResponse(VALID_EMBEDDING_RESPONSE)); + + await handleEmbeddingsCore(makeOptions({ + modelInfo: { provider: "openai", model: "text-embedding-ada-002" }, + credentials: { accessToken: "at-mytoken" }, + })); + + const [, init] = vi.mocked(fetch).mock.calls[0]; + 
expect(init.headers["Authorization"]).toBe("Bearer at-mytoken"); + }); + + it("openrouter → adds HTTP-Referer and X-Title headers", async () => { + vi.mocked(fetch).mockResolvedValueOnce(makeProviderResponse(VALID_EMBEDDING_RESPONSE)); + + await handleEmbeddingsCore(makeOptions({ + modelInfo: { provider: "openrouter", model: "openai/text-embedding-ada-002" }, + credentials: { apiKey: "sk-or-key" }, + })); + + const [, init] = vi.mocked(fetch).mock.calls[0]; + expect(init.headers["HTTP-Referer"]).toBeDefined(); + expect(init.headers["X-Title"]).toBeDefined(); + expect(init.headers["Authorization"]).toBe("Bearer sk-or-key"); + }); + + it("openai-compatible-* → Authorization: Bearer only (no extra headers)", async () => { + vi.mocked(fetch).mockResolvedValueOnce(makeProviderResponse(VALID_EMBEDDING_RESPONSE)); + + await handleEmbeddingsCore(makeOptions({ + modelInfo: { provider: "openai-compatible-local", model: "nomic-embed" }, + credentials: { + apiKey: "local-key", + providerSpecificData: { baseUrl: "http://localhost:11434/v1" }, + }, + })); + + const [, init] = vi.mocked(fetch).mock.calls[0]; + expect(init.headers["Authorization"]).toBe("Bearer local-key"); + expect(init.headers["HTTP-Referer"]).toBeUndefined(); + expect(init.headers["X-Title"]).toBeUndefined(); + }); +}); + +// ─── Test: handleEmbeddingsCore — input validation ─────────────────────────── + +describe("handleEmbeddingsCore — input validation", () => { + afterEach(() => { + vi.unstubAllGlobals(); + }); + + it("missing input → 400 Bad Request", async () => { + const result = await handleEmbeddingsCore(makeOptions({ + body: { model: "text-embedding-ada-002" }, // no input + })); + expect(result.success).toBe(false); + expect(result.status).toBe(400); + expect(result.error).toMatch(/missing required field: input/i); + }); + + it("input is a number → 400 Bad Request", async () => { + const result = await handleEmbeddingsCore(makeOptions({ + body: { model: "text-embedding-ada-002", input: 42 }, + })); + 
expect(result.success).toBe(false); + expect(result.status).toBe(400); + expect(result.error).toMatch(/input must be a string or array/i); + }); + + it("input is an object → 400 Bad Request", async () => { + const result = await handleEmbeddingsCore(makeOptions({ + body: { model: "text-embedding-ada-002", input: { text: "hello" } }, + })); + expect(result.success).toBe(false); + expect(result.status).toBe(400); + expect(result.error).toMatch(/input must be a string or array/i); + }); + + it("input is null → 400 Bad Request", async () => { + const result = await handleEmbeddingsCore(makeOptions({ + body: { model: "text-embedding-ada-002", input: null }, + })); + expect(result.success).toBe(false); + expect(result.status).toBe(400); + }); + + it("empty string input passes validation", async () => { + vi.stubGlobal("fetch", vi.fn().mockResolvedValueOnce( + makeProviderResponse(VALID_EMBEDDING_RESPONSE) + )); + const result = await handleEmbeddingsCore(makeOptions({ + body: { model: "text-embedding-ada-002", input: "" }, + })); + // Empty string is falsy → treated as missing + expect(result.success).toBe(false); + expect(result.status).toBe(400); + }); + + it("empty array input passes validation and reaches provider", async () => { + vi.stubGlobal("fetch", vi.fn().mockResolvedValueOnce( + makeProviderResponse(VALID_EMBEDDING_RESPONSE) + )); + const result = await handleEmbeddingsCore(makeOptions({ + body: { model: "text-embedding-ada-002", input: [] }, + })); + // Empty array is truthy → passes, fetch is called + expect(fetch).toHaveBeenCalledOnce(); + expect(result.success).toBe(true); + }); +}); + +// ─── Test: handleEmbeddingsCore — success path ─────────────────────────────── + +describe("handleEmbeddingsCore — success path", () => { + beforeEach(() => { + vi.stubGlobal("fetch", vi.fn()); + }); + afterEach(() => { + vi.unstubAllGlobals(); + }); + + it("returns success=true with Response on 200", async () => { + 
vi.mocked(fetch).mockResolvedValueOnce(makeProviderResponse(VALID_EMBEDDING_RESPONSE)); + + const result = await handleEmbeddingsCore(makeOptions()); + + expect(result.success).toBe(true); + expect(result.response).toBeInstanceOf(Response); + expect(result.response.status).toBe(200); + }); + + it("response body is valid OpenAI-format JSON", async () => { + vi.mocked(fetch).mockResolvedValueOnce(makeProviderResponse(VALID_EMBEDDING_RESPONSE)); + + const result = await handleEmbeddingsCore(makeOptions()); + const body = await result.response.json(); + + expect(body.object).toBe("list"); + expect(Array.isArray(body.data)).toBe(true); + expect(body.data[0]).toHaveProperty("embedding"); + expect(body.data[0]).toHaveProperty("index"); + }); + + it("response includes CORS header Access-Control-Allow-Origin: *", async () => { + vi.mocked(fetch).mockResolvedValueOnce(makeProviderResponse(VALID_EMBEDDING_RESPONSE)); + + const result = await handleEmbeddingsCore(makeOptions()); + + expect(result.response.headers.get("Access-Control-Allow-Origin")).toBe("*"); + }); + + it("response Content-Type is application/json", async () => { + vi.mocked(fetch).mockResolvedValueOnce(makeProviderResponse(VALID_EMBEDDING_RESPONSE)); + + const result = await handleEmbeddingsCore(makeOptions()); + + expect(result.response.headers.get("Content-Type")).toContain("application/json"); + }); + + it("calls onRequestSuccess callback on success", async () => { + vi.mocked(fetch).mockResolvedValueOnce(makeProviderResponse(VALID_EMBEDDING_RESPONSE)); + const onRequestSuccess = vi.fn(); + + await handleEmbeddingsCore(makeOptions({ onRequestSuccess })); + + expect(onRequestSuccess).toHaveBeenCalledOnce(); + }); + + it("does not call onRequestSuccess on provider error", async () => { + vi.mocked(fetch).mockResolvedValueOnce(makeProviderErrorResponse(500, "Server exploded")); + const onRequestSuccess = vi.fn(); + + await handleEmbeddingsCore(makeOptions({ onRequestSuccess })); + + 
expect(onRequestSuccess).not.toHaveBeenCalled(); + }); + + it("provider response with non-standard format is passed through as-is", async () => { + const nonStandardBody = { embeddings: [[0.1, 0.2]], model: "custom" }; + vi.mocked(fetch).mockResolvedValueOnce(makeProviderResponse(nonStandardBody)); + + const result = await handleEmbeddingsCore(makeOptions()); + const body = await result.response.json(); + + expect(body).toEqual(nonStandardBody); + }); +}); + +// ─── Test: handleEmbeddingsCore — provider error handling ──────────────────── + +describe("handleEmbeddingsCore — provider error handling", () => { + beforeEach(() => { + vi.stubGlobal("fetch", vi.fn()); + }); + afterEach(() => { + vi.unstubAllGlobals(); + }); + + it("provider 400 → returns success=false with status 400", async () => { + vi.mocked(fetch).mockResolvedValueOnce(makeProviderErrorResponse(400, "Bad model")); + + const result = await handleEmbeddingsCore(makeOptions()); + + expect(result.success).toBe(false); + expect(result.status).toBe(400); + }); + + it("provider 429 → returns success=false with status 429", async () => { + vi.mocked(fetch).mockResolvedValueOnce(makeProviderErrorResponse(429, "Rate limit exceeded")); + + const result = await handleEmbeddingsCore(makeOptions()); + + expect(result.success).toBe(false); + expect(result.status).toBe(429); + }); + + it("provider 500 → returns success=false with status 500", async () => { + vi.mocked(fetch).mockResolvedValueOnce(makeProviderErrorResponse(500, "Internal error")); + + const result = await handleEmbeddingsCore(makeOptions()); + + expect(result.success).toBe(false); + expect(result.status).toBe(500); + }); + + it("network error (fetch throws) → returns 502 Bad Gateway", async () => { + vi.mocked(fetch).mockRejectedValueOnce(new Error("ECONNREFUSED")); + + const result = await handleEmbeddingsCore(makeOptions()); + + expect(result.success).toBe(false); + expect(result.status).toBe(502); + expect(result.error).toMatch(/ECONNREFUSED/); + 
}); + + it("invalid JSON from provider → returns 502", async () => { + vi.mocked(fetch).mockResolvedValueOnce( + new Response("not json }{", { + status: 200, + headers: { "Content-Type": "text/plain" }, + }) + ); + + const result = await handleEmbeddingsCore(makeOptions()); + + expect(result.success).toBe(false); + expect(result.status).toBe(502); + }); + + it("error result response has OpenAI-format error body", async () => { + vi.mocked(fetch).mockResolvedValueOnce(makeProviderErrorResponse(400, "Bad model")); + + const result = await handleEmbeddingsCore(makeOptions()); + const body = await result.response.json(); + + expect(body).toHaveProperty("error"); + expect(body.error).toHaveProperty("message"); + }); +}); + +// ─── Test: handleEmbeddingsCore — token refresh on 401 ─────────────────────── + +describe("handleEmbeddingsCore — token refresh on 401/403", () => { + beforeEach(() => { + vi.stubGlobal("fetch", vi.fn()); + }); + afterEach(() => { + vi.unstubAllGlobals(); + }); + + it("on 401, attempts retry after refresh; succeeds if refresh gives new token", async () => { + // First call → 401 from provider + vi.mocked(fetch).mockResolvedValueOnce( + new Response(JSON.stringify({ error: "unauthorized" }), { + status: 401, + headers: { "Content-Type": "application/json" }, + }) + ); + // Second call (retry) → success + vi.mocked(fetch).mockResolvedValueOnce(makeProviderResponse(VALID_EMBEDDING_RESPONSE)); + + // Credentials with a refreshToken so the executor can try to refresh + const credentials = { + apiKey: "sk-old", + accessToken: "at-old", + refreshToken: "rt-valid", + }; + + // NOTE(review): nothing is mocked here — the file-level executors mock resolves refreshCredentials to null, so no new credentials are produced and `result` is never asserted on + const result = await handleEmbeddingsCore(makeOptions({ + modelInfo: { provider: "openai", model: "text-embedding-ada-002" }, + credentials, + onCredentialsRefreshed: vi.fn(), + })); + + // The handler may or may not succeed depending on whether the executor + // can refresh (openai executor likely can't). 
What we verify is that + // fetch was called at least once (the initial request). + expect(vi.mocked(fetch).mock.calls.length).toBeGreaterThanOrEqual(1); + }); + + it("on 401 with no refresh token, falls back gracefully (no crash)", async () => { + vi.mocked(fetch).mockResolvedValueOnce( + new Response(JSON.stringify({ error: "unauthorized" }), { + status: 401, + headers: { "Content-Type": "application/json" }, + }) + ); + + const result = await handleEmbeddingsCore(makeOptions({ + credentials: { apiKey: "sk-bad" }, + })); + + // Should return an error result, not throw + expect(result).toHaveProperty("success"); + expect(result.success).toBe(false); + }); +}); diff --git a/tests/vitest.config.js b/tests/vitest.config.js new file mode 100644 index 00000000..e5cd9ae8 --- /dev/null +++ b/tests/vitest.config.js @@ -0,0 +1,21 @@ +import { defineConfig } from "vitest/config"; +import { resolve } from "path"; +import { fileURLToPath } from "url"; + +const __dirname = fileURLToPath(new URL(".", import.meta.url)); + +export default defineConfig({ + test: { + environment: "node", + globals: true, + include: ["**/*.test.js"], + // Console output from handlers under test is left visible; set to true to suppress it + silent: false, + }, + resolve: { + alias: { + // Resolve open-sse/* imports to the actual local package + "open-sse": resolve(__dirname, "../open-sse"), + }, + }, +});