From 3e9523c49f005eb8a9b7170346af15ea7a25e8bf Mon Sep 17 00:00:00 2001 From: GuillaumeSD Date: Sun, 20 Apr 2025 03:49:15 +0200 Subject: [PATCH] fix : engine parallel workers --- src/hooks/useEngine.ts | 26 ++-- src/lib/engine/stockfish11.ts | 4 +- src/lib/engine/stockfish16.ts | 13 +- src/lib/engine/stockfish16_1.ts | 7 +- src/lib/engine/stockfish17.ts | 7 +- src/lib/engine/uciEngine.ts | 142 ++++++++---------- src/lib/engine/worker.ts | 18 ++- .../analysis/hooks/useCurrentPosition.ts | 6 +- .../analysis/panelHeader/analyzeButton.tsx | 8 +- src/sections/analysis/states.ts | 2 +- src/sections/play/board.tsx | 4 +- src/types/engine.ts | 8 + 12 files changed, 137 insertions(+), 108 deletions(-) diff --git a/src/hooks/useEngine.ts b/src/hooks/useEngine.ts index e51a6f7..23ac970 100644 --- a/src/hooks/useEngine.ts +++ b/src/hooks/useEngine.ts @@ -7,7 +7,10 @@ import { UciEngine } from "@/lib/engine/uciEngine"; import { EngineName } from "@/types/enums"; import { useEffect, useState } from "react"; -export const useEngine = (engineName: EngineName | undefined) => { +export const useEngine = ( + engineName: EngineName | undefined, + workersNb?: number +) => { const [engine, setEngine] = useState(null); useEffect(() => { @@ -17,7 +20,7 @@ export const useEngine = (engineName: EngineName | undefined) => { return; } - pickEngine(engineName).then((newEngine) => { + pickEngine(engineName, workersNb).then((newEngine) => { setEngine((prev) => { prev?.shutdown(); return newEngine; @@ -28,21 +31,24 @@ export const useEngine = (engineName: EngineName | undefined) => { return engine; }; -const pickEngine = (engine: EngineName): Promise => { +const pickEngine = ( + engine: EngineName, + workersNb?: number +): Promise => { switch (engine) { case EngineName.Stockfish17: - return Stockfish17.create(false); + return Stockfish17.create(false, workersNb); case EngineName.Stockfish17Lite: - return Stockfish17.create(true); + return Stockfish17.create(true, workersNb); case EngineName.Stockfish16_1: - return Stockfish16_1.create(false); + return Stockfish16_1.create(false, workersNb); case EngineName.Stockfish16_1Lite: - return Stockfish16_1.create(true); + return Stockfish16_1.create(true, workersNb); case EngineName.Stockfish16: - return Stockfish16.create(false); + return Stockfish16.create(false, workersNb); case EngineName.Stockfish16NNUE: - return Stockfish16.create(true); + return Stockfish16.create(true, workersNb); case EngineName.Stockfish11: - return Stockfish11.create(); + return Stockfish11.create(workersNb); } }; diff --git a/src/lib/engine/stockfish11.ts b/src/lib/engine/stockfish11.ts index 1c6f28a..f554e31 100644 --- a/src/lib/engine/stockfish11.ts +++ b/src/lib/engine/stockfish11.ts @@ -3,8 +3,8 @@ import { UciEngine } from "./uciEngine"; import { getEngineWorkers } from "./worker"; export class Stockfish11 { - public static async create(): Promise { - const workers = getEngineWorkers("engines/stockfish-11.js"); + public static async create(workersNb?: number): Promise { + const workers = getEngineWorkers("engines/stockfish-11.js", workersNb); return UciEngine.create(EngineName.Stockfish11, workers); } diff --git a/src/lib/engine/stockfish16.ts b/src/lib/engine/stockfish16.ts index ee463cc..a93e927 100644 --- a/src/lib/engine/stockfish16.ts +++ b/src/lib/engine/stockfish16.ts @@ -4,7 +4,10 @@ import { isMultiThreadSupported, isWasmSupported } from "./shared"; import { getEngineWorkers } from "./worker"; export class Stockfish16 { - public static async create(nnue?: boolean): Promise { + public static async create( + nnue?: boolean, + workersNb?: number + ): Promise { if (!Stockfish16.isSupported()) { throw new Error("Stockfish 16 is not supported"); } @@ -25,9 +28,13 @@ export class Stockfish16 { ); }; - const workers = getEngineWorkers(enginePath); + const engineName = nnue + ? EngineName.Stockfish16NNUE + : EngineName.Stockfish16; - return UciEngine.create(EngineName.Stockfish16, workers, customEngineInit); + const workers = getEngineWorkers(enginePath, workersNb); + + return UciEngine.create(engineName, workers, customEngineInit); } public static isSupported() { diff --git a/src/lib/engine/stockfish16_1.ts b/src/lib/engine/stockfish16_1.ts index 6a8232a..8caeddd 100644 --- a/src/lib/engine/stockfish16_1.ts +++ b/src/lib/engine/stockfish16_1.ts @@ -4,7 +4,10 @@ import { isMultiThreadSupported, isWasmSupported } from "./shared"; import { getEngineWorkers } from "./worker"; export class Stockfish16_1 { - public static async create(lite?: boolean): Promise { + public static async create( + lite?: boolean, + workersNb?: number + ): Promise { if (!Stockfish16_1.isSupported()) { throw new Error("Stockfish 16.1 is not supported"); } @@ -20,7 +23,7 @@ export class Stockfish16_1 { ? EngineName.Stockfish16_1Lite : EngineName.Stockfish16_1; - const workers = getEngineWorkers(enginePath); + const workers = getEngineWorkers(enginePath, workersNb); return UciEngine.create(engineName, workers); } diff --git a/src/lib/engine/stockfish17.ts b/src/lib/engine/stockfish17.ts index 49a7ed1..3932334 100644 --- a/src/lib/engine/stockfish17.ts +++ b/src/lib/engine/stockfish17.ts @@ -4,7 +4,10 @@ import { isMultiThreadSupported, isWasmSupported } from "./shared"; import { getEngineWorkers } from "./worker"; export class Stockfish17 { - public static async create(lite?: boolean): Promise { + public static async create( + lite?: boolean, + workersNb?: number + ): Promise { if (!Stockfish17.isSupported()) { throw new Error("Stockfish 17 is not supported"); } @@ -20,7 +23,7 @@ export class Stockfish17 { ? EngineName.Stockfish17Lite : EngineName.Stockfish17; - const workers = getEngineWorkers(enginePath); + const workers = getEngineWorkers(enginePath, workersNb); return UciEngine.create(engineName, workers); } diff --git a/src/lib/engine/uciEngine.ts b/src/lib/engine/uciEngine.ts index 60fb4da..3ca2795 100644 --- a/src/lib/engine/uciEngine.ts +++ b/src/lib/engine/uciEngine.ts @@ -13,20 +13,12 @@ import { computeAccuracy } from "./helpers/accuracy"; import { getIsStalemate, getWhoIsCheckmated } from "../chess"; import { getLichessEval } from "../lichess"; import { getMovesClassification } from "./helpers/moveClassification"; -import { EngineWorker } from "@/types/engine"; - -type WorkerJob = { - commands: string[]; - finalMessage: string; - onNewMessage?: (messages: string[]) => void; - resolve: (messages: string[]) => void; -}; +import { EngineWorker, WorkerJob } from "@/types/engine"; export class UciEngine { private workers: EngineWorker[]; - private isBusy: boolean[] = []; private workerQueue: WorkerJob[] = []; - private ready = false; + private isReady = false; private engineName: EngineName; private multiPv = 3; private skillLevel: number | undefined = undefined; @@ -34,7 +26,6 @@ export class UciEngine { private constructor(engineName: EngineName, workers: EngineWorker[]) { this.engineName = engineName; this.workers = workers; - this.isBusy = new Array(workers.length).fill(false); } public static async create( @@ -49,28 +40,31 @@ export class UciEngine { await engine.broadcastCommands(["uci"], "uciok"); await engine.setMultiPv(engine.multiPv, true); await customEngineInit?.(engine.sendCommands.bind(engine)); - engine.ready = true; + for (const worker of workers) { + worker.isReady = true; + } + engine.isReady = true; console.log(`${engineName} initialized`); return engine; } - private acquireWorker(): { index: number; worker: EngineWorker } | undefined { - for (let i = 0; i < this.workers.length; i++) { - if (!this.isBusy[i]) { - this.isBusy[i] = true; - return { index: i, worker: this.workers[i] }; - } + private acquireWorker(): EngineWorker | undefined { + for (const worker of this.workers) { + if (!worker.isReady) continue; + + worker.isReady = false; + return worker; } return undefined; } - private releaseWorker(index: number) { - this.isBusy[index] = false; + private releaseWorker(worker: EngineWorker) { + worker.isReady = true; + const nextJob = this.workerQueue.shift(); - if (this.workerQueue.length > 0) { - const nextJob = this.workerQueue.shift()!; + if (nextJob) { this.sendCommands( nextJob.commands, nextJob.finalMessage, @@ -109,7 +103,7 @@ export class UciEngine { throw new Error(`Invalid SkillLevel value : ${skillLevel}`); } - await this.sendCommands( + await this.broadcastCommands( [`setoption name Skill Level value ${skillLevel}`, "isready"], "readyok" ); @@ -118,31 +112,35 @@ export class UciEngine { } private throwErrorIfNotReady() { - if (!this.ready) { + if (!this.isReady) { throw new Error(`${this.engineName} is not ready`); } } public shutdown(): void { - this.ready = false; + this.isReady = false; + this.workerQueue = []; for (const worker of this.workers) { worker.uci("quit"); worker.terminate?.(); + worker.isReady = false; } - this.isBusy = Array(this.workers.length).fill(false); - this.workerQueue = []; - console.log(`${this.engineName} shutdown`); } - public isReady(): boolean { - return this.ready; + public getIsReady(): boolean { + return this.isReady; } public async stopSearch(): Promise { - await this.sendCommands(["stop", "isready"], "readyok"); + this.workerQueue = []; + await this.broadcastCommands(["stop", "isready"], "readyok"); + + for (const worker of this.workers) { + this.releaseWorker(worker); + } } private async sendCommands( @@ -150,8 +148,9 @@ export class UciEngine { finalMessage: string, onNewMessage?: (messages: string[]) => void ): Promise { - const acquired = this.acquireWorker(); - if (!acquired) { + const worker = this.acquireWorker(); + + if (!worker) { return new Promise((resolve) => { this.workerQueue.push({ commands, @@ -161,21 +160,13 @@ export class UciEngine { }); }); } - return new Promise((resolve) => { - const messages: string[] = []; - acquired.worker.listen = (data) => { - messages.push(data); - onNewMessage?.(messages); - if (data.startsWith(finalMessage)) { - this.releaseWorker(acquired.index); - resolve(messages); - } - }; - for (const command of commands) { - acquired.worker.uci(command); - } - }); + return this.sendCommandsToWorker( + worker, + commands, + finalMessage, + onNewMessage + ); } private async sendCommandsToWorker( @@ -189,7 +180,9 @@ export class UciEngine { worker.listen = (data) => { messages.push(data); onNewMessage?.(messages); + if (data.startsWith(finalMessage)) { + this.releaseWorker(worker); resolve(messages); } }; @@ -199,13 +192,15 @@ export class UciEngine { }); } - private broadcastCommands( + private async broadcastCommands( commands: string[], finalMessage: string, onNewMessage?: (messages: string[]) => void - ): Promise[] { - return this.workers.map((worker) => - this.sendCommandsToWorker(worker, commands, finalMessage, onNewMessage) + ): Promise { + await Promise.all( + this.workers.map((worker) => + this.sendCommandsToWorker(worker, commands, finalMessage, onNewMessage) + ) ); } @@ -219,17 +214,16 @@ export class UciEngine { this.throwErrorIfNotReady(); setEvaluationProgress?.(1); await this.setMultiPv(multiPv); - this.ready = false; + this.isReady = false; - await this.sendCommands( - ["ucinewgame", "position startpos", "isready"], - "readyok" - ); + await this.broadcastCommands(["ucinewgame", "isready"], "readyok"); const positions: PositionEval[] = new Array(fens.length); let completed = 0; - const updateProgress = () => { + const updateEval = (index: number, positionEval: PositionEval) => { + completed++; + positions[index] = positionEval; const progress = completed / fens.length; setEvaluationProgress?.(99 - Math.exp(-4 * progress) * 99); }; @@ -238,7 +232,7 @@ export class UciEngine { fens.map(async (fen, i) => { const whoIsCheckmated = getWhoIsCheckmated(fen); if (whoIsCheckmated) { - positions[i] = { + updateEval(i, { lines: [ { pv: [], @@ -247,15 +241,13 @@ export class UciEngine { mate: whoIsCheckmated === "w" ? -1 : 1, }, ], - }; - completed++; - updateProgress(); + }); return; } const isStalemate = getIsStalemate(fen); if (isStalemate) { - positions[i] = { + updateEval(i, { lines: [ { pv: [], @@ -264,16 +256,12 @@ export class UciEngine { cp: 0, }, ], - }; - completed++; - updateProgress(); + }); return; } const result = await this.evaluatePosition(fen, depth); - positions[i] = result; - completed++; - updateProgress(); + updateEval(i, result); }) ); @@ -284,7 +272,7 @@ export class UciEngine { ); const accuracy = computeAccuracy(positions); - this.ready = true; + this.isReady = true; return { positions: positionsWithClassification, accuracy, @@ -301,14 +289,14 @@ export class UciEngine { fen: string, depth = 16 ): Promise { - console.log(`Evaluating position: ${fen}`); - - const lichessEval = await getLichessEval(fen, this.multiPv); - if ( - lichessEval.lines.length >= this.multiPv && - lichessEval.lines[0].depth >= depth - ) { - return lichessEval; + if (this.workers.length < 2) { + const lichessEval = await getLichessEval(fen, this.multiPv); + if ( + lichessEval.lines.length >= this.multiPv && + lichessEval.lines[0].depth >= depth + ) { + return lichessEval; + } } const results = await this.sendCommands( diff --git a/src/lib/engine/worker.ts b/src/lib/engine/worker.ts index 4df9862..18ab9b3 100644 --- a/src/lib/engine/worker.ts +++ b/src/lib/engine/worker.ts @@ -1,15 +1,25 @@ import { EngineWorker } from "@/types/engine"; -export const getEngineWorkers = (enginePath: string): EngineWorker[] => { +export const getEngineWorkers = ( + enginePath: string, + workersInputNb?: number +): EngineWorker[] => { + if (workersInputNb !== undefined && workersInputNb < 1) { + throw new Error( + `Number of workers must be greater than 0, got ${workersInputNb} instead` + ); + } + const engineWorkers: EngineWorker[] = []; - const instanceCount = - navigator.hardwareConcurrency - (navigator.hardwareConcurrency % 2 ? 0 : 1); + const maxWorkersNb = Math.max(1, navigator.hardwareConcurrency - 4); + const workersNb = workersInputNb ?? maxWorkersNb; - for (let i = 0; i < instanceCount; i++) { + for (let i = 0; i < workersNb; i++) { const worker = new Worker(enginePath); const engineWorker: EngineWorker = { + isReady: false, uci: (command: string) => worker.postMessage(command), listen: () => null, terminate: () => worker.terminate(), diff --git a/src/sections/analysis/hooks/useCurrentPosition.ts b/src/sections/analysis/hooks/useCurrentPosition.ts index 395e97a..4b21389 100644 --- a/src/sections/analysis/hooks/useCurrentPosition.ts +++ b/src/sections/analysis/hooks/useCurrentPosition.ts @@ -17,7 +17,7 @@ import { getMovesClassification } from "@/lib/engine/helpers/moveClassification" export const useCurrentPosition = (engineName?: EngineName) => { const [currentPosition, setCurrentPosition] = useAtom(currentPositionAtom); - const engine = useEngine(engineName); + const engine = useEngine(engineName, 1); const gameEval = useAtomValue(gameEvalAtom); const game = useAtomValue(gameAtom); const board = useAtomValue(boardAtom); @@ -52,7 +52,7 @@ export const useCurrentPosition = (engineName?: EngineName) => { if ( !position.eval && - engine?.isReady() && + engine?.getIsReady() && engineName && !board.isCheckmate() && !board.isStalemate() @@ -61,7 +61,7 @@ export const useCurrentPosition = (engineName?: EngineName) => { fen: string, setPartialEval?: (positionEval: PositionEval) => void ) => { - if (!engine?.isReady() || !engineName) + if (!engine?.getIsReady() || !engineName) throw new Error("Engine not ready"); const savedEval = savedEvals[fen]; if ( diff --git a/src/sections/analysis/panelHeader/analyzeButton.tsx b/src/sections/analysis/panelHeader/analyzeButton.tsx index 8cf56ee..7e685e5 100644 --- a/src/sections/analysis/panelHeader/analyzeButton.tsx +++ b/src/sections/analysis/panelHeader/analyzeButton.tsx @@ -30,11 +30,15 @@ export default function AnalyzeButton() { const setSavedEvals = useSetAtom(savedEvalsAtom); const readyToAnalyse = - engine?.isReady() && game.history().length > 0 && !evaluationProgress; + engine?.getIsReady() && game.history().length > 0 && !evaluationProgress; const handleAnalyze = async () => { const params = getEvaluateGameParams(game); - if (!engine?.isReady() || params.fens.length === 0 || evaluationProgress) { + if ( + !engine?.getIsReady() || + params.fens.length === 0 || + evaluationProgress + ) { return; } diff --git a/src/sections/analysis/states.ts b/src/sections/analysis/states.ts index 591bfe2..ff59887 100644 --- a/src/sections/analysis/states.ts +++ b/src/sections/analysis/states.ts @@ -13,7 +13,7 @@ export const showBestMoveArrowAtom = atom(true); export const showPlayerMoveIconAtom = atom(true); export const engineNameAtom = atom(EngineName.Stockfish17Lite); -export const engineDepthAtom = atom(16); +export const engineDepthAtom = atom(14); export const engineMultiPvAtom = atom(3); export const evaluationProgressAtom = atom(0); diff --git a/src/sections/play/board.tsx b/src/sections/play/board.tsx index a295b95..b4fe592 100644 --- a/src/sections/play/board.tsx +++ b/src/sections/play/board.tsx @@ -19,7 +19,7 @@ import { useGameData } from "@/hooks/useGameData"; export default function BoardContainer() { const screenSize = useScreenSize(); const engineName = useAtomValue(enginePlayNameAtom); - const engine = useEngine(engineName); + const engine = useEngine(engineName, 1); const game = useAtomValue(gameAtom); const playerColor = useAtomValue(playerColorAtom); const { makeMove: makeGameMove } = useChessActions(gameAtom); @@ -32,7 +32,7 @@ export default function BoardContainer() { useEffect(() => { const playEngineMove = async () => { if ( - !engine?.isReady() || + !engine?.getIsReady() || game.turn() === playerColor || isGameFinished || !isGameInProgress diff --git a/src/types/engine.ts b/src/types/engine.ts index b83f32f..4c664e3 100644 --- a/src/types/engine.ts +++ b/src/types/engine.ts @@ -1,6 +1,14 @@ export interface EngineWorker { + isReady: boolean; uci(command: string): void; listen: (data: string) => void; terminate?: () => void; setNnueBuffer?: (data: Uint8Array, index?: number) => void; } + +export interface WorkerJob { + commands: string[]; + finalMessage: string; + onNewMessage?: (messages: string[]) => void; + resolve: (messages: string[]) => void; +}