diff --git a/open-sse/services/combo.js b/open-sse/services/combo.js index 34598b64..6a4b17df 100644 --- a/open-sse/services/combo.js +++ b/open-sse/services/combo.js @@ -7,35 +7,60 @@ import { unavailableResponse } from "../utils/error.js"; /** * Track rotation state per combo (for round-robin strategy) - * @type {Map} + * @type {Map} */ const comboRotationState = new Map(); +function normalizeStickyLimit(stickyLimit) { + const parsed = Number.parseInt(stickyLimit, 10); + return Number.isFinite(parsed) && parsed > 0 ? parsed : 1; +} + +function rotateModelsFromIndex(models, currentIndex) { + const rotatedModels = [...models]; + for (let i = 0; i < currentIndex; i++) { + const moved = rotatedModels.shift(); + rotatedModels.push(moved); + } + return rotatedModels; +} + /** * Get rotated model list based on strategy * @param {string[]} models - Array of model strings * @param {string} comboName - Name of the combo * @param {string} strategy - "fallback" or "round-robin" + * @param {number|string} [stickyLimit=1] - Requests per combo model before switching * @returns {string[]} Rotated models array */ -export function getRotatedModels(models, comboName, strategy) { +export function getRotatedModels(models, comboName, strategy, stickyLimit = 1) { if (!models || models.length <= 1 || strategy !== "round-robin") { return models; } - const currentIndex = comboRotationState.get(comboName) || 0; - const rotatedModels = [...models]; - - // Rotate: move models from currentIndex to front, preserving order after - for (let i = 0; i < currentIndex; i++) { - const moved = rotatedModels.shift(); - rotatedModels.push(moved); + const rotationKey = comboName || "__default__"; + const normalizedStickyLimit = normalizeStickyLimit(stickyLimit); + const existingState = comboRotationState.get(rotationKey); + const state = typeof existingState === "number" + ? { index: existingState, consecutiveUseCount: 0 } + : (existingState || { index: 0, consecutiveUseCount: 0 }); + + const currentIndex = state.index % models.length; + const rotatedModels = rotateModelsFromIndex(models, currentIndex); + const nextUseCount = state.consecutiveUseCount + 1; + + if (nextUseCount >= normalizedStickyLimit) { + comboRotationState.set(rotationKey, { + index: (currentIndex + 1) % models.length, + consecutiveUseCount: 0, + }); + } else { + comboRotationState.set(rotationKey, { + index: currentIndex, + consecutiveUseCount: nextUseCount, + }); } - - // Update state for next request (cycle through all models) - const nextIndex = (currentIndex + 1) % models.length; - comboRotationState.set(comboName, nextIndex); - + return rotatedModels; } @@ -77,11 +102,12 @@ export function getComboModelsFromData(modelStr, combosData) { * @param {Object} options.log - Logger object * @param {string} [options.comboName] - Name of the combo (for round-robin tracking) * @param {string} [options.comboStrategy] - Strategy: "fallback" or "round-robin" + * @param {number|string} [options.comboStickyLimit=1] - Requests per combo model before switching * @returns {Promise} */ -export async function handleComboChat({ body, models, handleSingleModel, log, comboName, comboStrategy }) { +export async function handleComboChat({ body, models, handleSingleModel, log, comboName, comboStrategy, comboStickyLimit = 1 }) { // Apply rotation strategy if enabled - const rotatedModels = getRotatedModels(models, comboName, comboStrategy); + const rotatedModels = getRotatedModels(models, comboName, comboStrategy, comboStickyLimit); let lastError = null; let earliestRetryAfter = null; diff --git a/src/app/(dashboard)/dashboard/profile/page.js b/src/app/(dashboard)/dashboard/profile/page.js index 203020c6..4c55e455 100644 --- a/src/app/(dashboard)/dashboard/profile/page.js +++ b/src/app/(dashboard)/dashboard/profile/page.js @@ -223,6 +223,24 @@ export default function ProfilePage() { } }; + const updateComboStickyLimit = async (limit) => { + const numLimit = parseInt(limit); + if (isNaN(numLimit) || numLimit < 1) return; + + try { + const res = await fetch("/api/settings", { + method: "PATCH", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ comboStickyRoundRobinLimit: numLimit }), + }); + if (res.ok) { + setSettings(prev => ({ ...prev, comboStickyRoundRobinLimit: numLimit })); + } + } catch (err) { + console.error("Failed to update combo sticky limit:", err); + } + }; + const updateRequireLogin = async (requireLogin) => { try { const res = await fetch("/api/settings", { @@ -550,10 +568,34 @@ export default function ProfilePage() { /> + {/* Combo Sticky Round Robin Limit */} + {settings.comboStrategy === "round-robin" && ( +
+
+

Combo Sticky Limit

+

+ Calls per combo model before switching +

+
+ updateComboStickyLimit(e.target.value)} + disabled={loading} + className="w-20 text-center" + /> +
+ )} +

{settings.fallbackStrategy === "round-robin" ? `Currently distributing requests across all available accounts with ${settings.stickyRoundRobinLimit || 3} calls per account.` : "Currently using accounts in priority order (Fill First)."} + {settings.comboStrategy === "round-robin" + ? ` Combos rotate after ${settings.comboStickyRoundRobinLimit || 1} call${(settings.comboStickyRoundRobinLimit || 1) === 1 ? "" : "s"} per model.` + : " Combos always start with their first model."}

diff --git a/src/app/api/settings/route.js b/src/app/api/settings/route.js index ddd91c19..50593b36 100644 --- a/src/app/api/settings/route.js +++ b/src/app/api/settings/route.js @@ -70,6 +70,7 @@ export async function PATCH(request) { // Invalidate combo rotation state when strategy settings change if ( Object.prototype.hasOwnProperty.call(body, "comboStrategy") || + Object.prototype.hasOwnProperty.call(body, "comboStickyRoundRobinLimit") || Object.prototype.hasOwnProperty.call(body, "comboStrategies") ) { resetComboRotation(); diff --git a/src/lib/localDb.js b/src/lib/localDb.js index 8ad959aa..0ea2d6a4 100644 --- a/src/lib/localDb.js +++ b/src/lib/localDb.js @@ -24,6 +24,7 @@ const DEFAULT_SETTINGS = { stickyRoundRobinLimit: 3, providerStrategies: {}, comboStrategy: "fallback", + comboStickyRoundRobinLimit: 1, comboStrategies: {}, requireLogin: true, tunnelDashboardAccess: true, diff --git a/src/sse/handlers/chat.js b/src/sse/handlers/chat.js index 60a718b5..0d7e4a9f 100644 --- a/src/sse/handlers/chat.js +++ b/src/sse/handlers/chat.js @@ -97,14 +97,16 @@ export async function handleChat(request, clientRawRequest = null) { const comboSpecificStrategy = comboStrategies[modelStr]?.fallbackStrategy; const comboStrategy = comboSpecificStrategy || settings.comboStrategy || "fallback"; - log.info("CHAT", `Combo "${modelStr}" with ${comboModels.length} models (strategy: ${comboStrategy})`); + const comboStickyLimit = settings.comboStickyRoundRobinLimit; + log.info("CHAT", `Combo "${modelStr}" with ${comboModels.length} models (strategy: ${comboStrategy}, sticky: ${comboStickyLimit})`); return handleComboChat({ body, models: comboModels, handleSingleModel: (b, m) => handleSingleModelChat(b, m, clientRawRequest, request, apiKey), log, comboName: modelStr, - comboStrategy + comboStrategy, + comboStickyLimit }); } @@ -128,14 +130,16 @@ async function handleSingleModelChat(body, modelStr, clientRawRequest = null, re const comboSpecificStrategy = comboStrategies[modelStr]?.fallbackStrategy; const comboStrategy = comboSpecificStrategy || chatSettings.comboStrategy || "fallback"; - log.info("CHAT", `Combo "${modelStr}" with ${comboModels.length} models (strategy: ${comboStrategy})`); + const comboStickyLimit = chatSettings.comboStickyRoundRobinLimit; + log.info("CHAT", `Combo "${modelStr}" with ${comboModels.length} models (strategy: ${comboStrategy}, sticky: ${comboStickyLimit})`); return handleComboChat({ body, models: comboModels, handleSingleModel: (b, m) => handleSingleModelChat(b, m, clientRawRequest, request, apiKey), log, comboName: modelStr, - comboStrategy + comboStrategy, + comboStickyLimit }); } log.warn("CHAT", "Invalid model format", { model: modelStr }); diff --git a/src/sse/handlers/fetch.js b/src/sse/handlers/fetch.js index 543eb586..6d096d9a 100644 --- a/src/sse/handlers/fetch.js +++ b/src/sse/handlers/fetch.js @@ -84,14 +84,16 @@ export async function handleFetch(request) { if (comboModels) { const comboStrategies = settings.comboStrategies || {}; const comboStrategy = comboStrategies[providerInput]?.fallbackStrategy || settings.comboStrategy || "fallback"; - log.info("FETCH", `Combo "${providerInput}" with ${comboModels.length} providers (strategy: ${comboStrategy})`); + const comboStickyLimit = settings.comboStickyRoundRobinLimit; + log.info("FETCH", `Combo "${providerInput}" with ${comboModels.length} providers (strategy: ${comboStrategy}, sticky: ${comboStickyLimit})`); return handleComboChat({ body, models: comboModels, handleSingleModel: (b, m) => handleSingleProviderFetch(b, m, request, apiKey, settings), log, comboName: providerInput, - comboStrategy + comboStrategy, + comboStickyLimit }); } diff --git a/src/sse/handlers/search.js b/src/sse/handlers/search.js index 2d062818..d8ee6b74 100644 --- a/src/sse/handlers/search.js +++ b/src/sse/handlers/search.js @@ -74,14 +74,16 @@ export async function handleSearch(request) { if (comboModels) { const comboStrategies = settings.comboStrategies || {}; const comboStrategy = comboStrategies[providerInput]?.fallbackStrategy || settings.comboStrategy || "fallback"; - log.info("SEARCH", `Combo "${providerInput}" with ${comboModels.length} providers (strategy: ${comboStrategy})`); + const comboStickyLimit = settings.comboStickyRoundRobinLimit; + log.info("SEARCH", `Combo "${providerInput}" with ${comboModels.length} providers (strategy: ${comboStrategy}, sticky: ${comboStickyLimit})`); return handleComboChat({ body, models: comboModels, handleSingleModel: (b, m) => handleSingleProviderSearch(b, m, request, apiKey, settings), log, comboName: providerInput, - comboStrategy + comboStrategy, + comboStickyLimit }); } diff --git a/tests/unit/combo-routing.test.js b/tests/unit/combo-routing.test.js new file mode 100644 index 00000000..d6ef4d04 --- /dev/null +++ b/tests/unit/combo-routing.test.js @@ -0,0 +1,58 @@ +import { describe, it, expect, beforeEach } from "vitest"; + +import { getRotatedModels, resetComboRotation } from "../../open-sse/services/combo.js"; + +describe("combo round-robin routing", () => { + beforeEach(() => { + resetComboRotation(); + }); + + it("keeps existing one-request round-robin behavior by default", () => { + const models = ["provider/model-a", "provider/model-b"]; + + const firstChoices = Array.from({ length: 4 }, () => ( + getRotatedModels(models, "code-xhigh", "round-robin")[0] + )); + + expect(firstChoices).toEqual([ + "provider/model-a", + "provider/model-b", + "provider/model-a", + "provider/model-b", + ]); + }); + + it("sticks to each combo model for the configured number of requests", () => { + const models = ["provider/model-a", "provider/model-b"]; + + const firstChoices = Array.from({ length: 6 }, () => ( + getRotatedModels(models, "code-xhigh", "round-robin", 2)[0] + )); + + expect(firstChoices).toEqual([ + "provider/model-a", + "provider/model-a", + "provider/model-b", + "provider/model-b", + "provider/model-a", + "provider/model-a", + ]); + }); + + it("tracks sticky rotation independently per combo", () => { + const models = ["provider/model-a", "provider/model-b"]; + + expect(getRotatedModels(models, "code-high", "round-robin", 2)[0]).toBe("provider/model-a"); + expect(getRotatedModels(models, "code-xhigh", "round-robin", 2)[0]).toBe("provider/model-a"); + expect(getRotatedModels(models, "code-high", "round-robin", 2)[0]).toBe("provider/model-a"); + expect(getRotatedModels(models, "code-high", "round-robin", 2)[0]).toBe("provider/model-b"); + expect(getRotatedModels(models, "code-xhigh", "round-robin", 2)[0]).toBe("provider/model-a"); + }); + + it("does not rotate fallback combos", () => { + const models = ["provider/model-a", "provider/model-b"]; + + expect(getRotatedModels(models, "code-xhigh", "fallback", 2)).toEqual(models); + expect(getRotatedModels(models, "code-xhigh", "fallback", 2)).toEqual(models); + }); +});