fix : engine parallel workers

This commit is contained in:
GuillaumeSD
2025-04-20 03:49:15 +02:00
parent 61c90b9c6b
commit 3e9523c49f
12 changed files with 137 additions and 108 deletions

View File

@@ -7,7 +7,10 @@ import { UciEngine } from "@/lib/engine/uciEngine";
import { EngineName } from "@/types/enums"; import { EngineName } from "@/types/enums";
import { useEffect, useState } from "react"; import { useEffect, useState } from "react";
export const useEngine = (engineName: EngineName | undefined) => { export const useEngine = (
engineName: EngineName | undefined,
workersNb?: number
) => {
const [engine, setEngine] = useState<UciEngine | null>(null); const [engine, setEngine] = useState<UciEngine | null>(null);
useEffect(() => { useEffect(() => {
@@ -17,7 +20,7 @@ export const useEngine = (engineName: EngineName | undefined) => {
return; return;
} }
pickEngine(engineName).then((newEngine) => { pickEngine(engineName, workersNb).then((newEngine) => {
setEngine((prev) => { setEngine((prev) => {
prev?.shutdown(); prev?.shutdown();
return newEngine; return newEngine;
@@ -28,21 +31,24 @@ export const useEngine = (engineName: EngineName | undefined) => {
return engine; return engine;
}; };
const pickEngine = (engine: EngineName): Promise<UciEngine> => { const pickEngine = (
engine: EngineName,
workersNb?: number
): Promise<UciEngine> => {
switch (engine) { switch (engine) {
case EngineName.Stockfish17: case EngineName.Stockfish17:
return Stockfish17.create(false); return Stockfish17.create(false, workersNb);
case EngineName.Stockfish17Lite: case EngineName.Stockfish17Lite:
return Stockfish17.create(true); return Stockfish17.create(true, workersNb);
case EngineName.Stockfish16_1: case EngineName.Stockfish16_1:
return Stockfish16_1.create(false); return Stockfish16_1.create(false, workersNb);
case EngineName.Stockfish16_1Lite: case EngineName.Stockfish16_1Lite:
return Stockfish16_1.create(true); return Stockfish16_1.create(true, workersNb);
case EngineName.Stockfish16: case EngineName.Stockfish16:
return Stockfish16.create(false); return Stockfish16.create(false, workersNb);
case EngineName.Stockfish16NNUE: case EngineName.Stockfish16NNUE:
return Stockfish16.create(true); return Stockfish16.create(true, workersNb);
case EngineName.Stockfish11: case EngineName.Stockfish11:
return Stockfish11.create(); return Stockfish11.create(workersNb);
} }
}; };

View File

@@ -3,8 +3,8 @@ import { UciEngine } from "./uciEngine";
import { getEngineWorkers } from "./worker"; import { getEngineWorkers } from "./worker";
export class Stockfish11 { export class Stockfish11 {
public static async create(): Promise<UciEngine> { public static async create(workersNb?: number): Promise<UciEngine> {
const workers = getEngineWorkers("engines/stockfish-11.js"); const workers = getEngineWorkers("engines/stockfish-11.js", workersNb);
return UciEngine.create(EngineName.Stockfish11, workers); return UciEngine.create(EngineName.Stockfish11, workers);
} }

View File

@@ -4,7 +4,10 @@ import { isMultiThreadSupported, isWasmSupported } from "./shared";
import { getEngineWorkers } from "./worker"; import { getEngineWorkers } from "./worker";
export class Stockfish16 { export class Stockfish16 {
public static async create(nnue?: boolean): Promise<UciEngine> { public static async create(
nnue?: boolean,
workersNb?: number
): Promise<UciEngine> {
if (!Stockfish16.isSupported()) { if (!Stockfish16.isSupported()) {
throw new Error("Stockfish 16 is not supported"); throw new Error("Stockfish 16 is not supported");
} }
@@ -25,9 +28,13 @@ export class Stockfish16 {
); );
}; };
const workers = getEngineWorkers(enginePath); const engineName = nnue
? EngineName.Stockfish16NNUE
: EngineName.Stockfish16;
return UciEngine.create(EngineName.Stockfish16, workers, customEngineInit); const workers = getEngineWorkers(enginePath, workersNb);
return UciEngine.create(engineName, workers, customEngineInit);
} }
public static isSupported() { public static isSupported() {

View File

@@ -4,7 +4,10 @@ import { isMultiThreadSupported, isWasmSupported } from "./shared";
import { getEngineWorkers } from "./worker"; import { getEngineWorkers } from "./worker";
export class Stockfish16_1 { export class Stockfish16_1 {
public static async create(lite?: boolean): Promise<UciEngine> { public static async create(
lite?: boolean,
workersNb?: number
): Promise<UciEngine> {
if (!Stockfish16_1.isSupported()) { if (!Stockfish16_1.isSupported()) {
throw new Error("Stockfish 16.1 is not supported"); throw new Error("Stockfish 16.1 is not supported");
} }
@@ -20,7 +23,7 @@ export class Stockfish16_1 {
? EngineName.Stockfish16_1Lite ? EngineName.Stockfish16_1Lite
: EngineName.Stockfish16_1; : EngineName.Stockfish16_1;
const workers = getEngineWorkers(enginePath); const workers = getEngineWorkers(enginePath, workersNb);
return UciEngine.create(engineName, workers); return UciEngine.create(engineName, workers);
} }

View File

@@ -4,7 +4,10 @@ import { isMultiThreadSupported, isWasmSupported } from "./shared";
import { getEngineWorkers } from "./worker"; import { getEngineWorkers } from "./worker";
export class Stockfish17 { export class Stockfish17 {
public static async create(lite?: boolean): Promise<UciEngine> { public static async create(
lite?: boolean,
workersNb?: number
): Promise<UciEngine> {
if (!Stockfish17.isSupported()) { if (!Stockfish17.isSupported()) {
throw new Error("Stockfish 17 is not supported"); throw new Error("Stockfish 17 is not supported");
} }
@@ -20,7 +23,7 @@ export class Stockfish17 {
? EngineName.Stockfish17Lite ? EngineName.Stockfish17Lite
: EngineName.Stockfish17; : EngineName.Stockfish17;
const workers = getEngineWorkers(enginePath); const workers = getEngineWorkers(enginePath, workersNb);
return UciEngine.create(engineName, workers); return UciEngine.create(engineName, workers);
} }

View File

@@ -13,20 +13,12 @@ import { computeAccuracy } from "./helpers/accuracy";
import { getIsStalemate, getWhoIsCheckmated } from "../chess"; import { getIsStalemate, getWhoIsCheckmated } from "../chess";
import { getLichessEval } from "../lichess"; import { getLichessEval } from "../lichess";
import { getMovesClassification } from "./helpers/moveClassification"; import { getMovesClassification } from "./helpers/moveClassification";
import { EngineWorker } from "@/types/engine"; import { EngineWorker, WorkerJob } from "@/types/engine";
type WorkerJob = {
commands: string[];
finalMessage: string;
onNewMessage?: (messages: string[]) => void;
resolve: (messages: string[]) => void;
};
export class UciEngine { export class UciEngine {
private workers: EngineWorker[]; private workers: EngineWorker[];
private isBusy: boolean[] = [];
private workerQueue: WorkerJob[] = []; private workerQueue: WorkerJob[] = [];
private ready = false; private isReady = false;
private engineName: EngineName; private engineName: EngineName;
private multiPv = 3; private multiPv = 3;
private skillLevel: number | undefined = undefined; private skillLevel: number | undefined = undefined;
@@ -34,7 +26,6 @@ export class UciEngine {
private constructor(engineName: EngineName, workers: EngineWorker[]) { private constructor(engineName: EngineName, workers: EngineWorker[]) {
this.engineName = engineName; this.engineName = engineName;
this.workers = workers; this.workers = workers;
this.isBusy = new Array(workers.length).fill(false);
} }
public static async create( public static async create(
@@ -49,28 +40,31 @@ export class UciEngine {
await engine.broadcastCommands(["uci"], "uciok"); await engine.broadcastCommands(["uci"], "uciok");
await engine.setMultiPv(engine.multiPv, true); await engine.setMultiPv(engine.multiPv, true);
await customEngineInit?.(engine.sendCommands.bind(engine)); await customEngineInit?.(engine.sendCommands.bind(engine));
engine.ready = true; for (const worker of workers) {
worker.isReady = true;
}
engine.isReady = true;
console.log(`${engineName} initialized`); console.log(`${engineName} initialized`);
return engine; return engine;
} }
private acquireWorker(): { index: number; worker: EngineWorker } | undefined { private acquireWorker(): EngineWorker | undefined {
for (let i = 0; i < this.workers.length; i++) { for (const worker of this.workers) {
if (!this.isBusy[i]) { if (!worker.isReady) continue;
this.isBusy[i] = true;
return { index: i, worker: this.workers[i] }; worker.isReady = false;
} return worker;
} }
return undefined; return undefined;
} }
private releaseWorker(index: number) { private releaseWorker(worker: EngineWorker) {
this.isBusy[index] = false; worker.isReady = true;
const nextJob = this.workerQueue.shift();
if (this.workerQueue.length > 0) { if (nextJob) {
const nextJob = this.workerQueue.shift()!;
this.sendCommands( this.sendCommands(
nextJob.commands, nextJob.commands,
nextJob.finalMessage, nextJob.finalMessage,
@@ -109,7 +103,7 @@ export class UciEngine {
throw new Error(`Invalid SkillLevel value : ${skillLevel}`); throw new Error(`Invalid SkillLevel value : ${skillLevel}`);
} }
await this.sendCommands( await this.broadcastCommands(
[`setoption name Skill Level value ${skillLevel}`, "isready"], [`setoption name Skill Level value ${skillLevel}`, "isready"],
"readyok" "readyok"
); );
@@ -118,31 +112,35 @@ export class UciEngine {
} }
private throwErrorIfNotReady() { private throwErrorIfNotReady() {
if (!this.ready) { if (!this.isReady) {
throw new Error(`${this.engineName} is not ready`); throw new Error(`${this.engineName} is not ready`);
} }
} }
public shutdown(): void { public shutdown(): void {
this.ready = false; this.isReady = false;
this.workerQueue = [];
for (const worker of this.workers) { for (const worker of this.workers) {
worker.uci("quit"); worker.uci("quit");
worker.terminate?.(); worker.terminate?.();
worker.isReady = false;
} }
this.isBusy = Array(this.workers.length).fill(false);
this.workerQueue = [];
console.log(`${this.engineName} shutdown`); console.log(`${this.engineName} shutdown`);
} }
public isReady(): boolean { public getIsReady(): boolean {
return this.ready; return this.isReady;
} }
public async stopSearch(): Promise<void> { public async stopSearch(): Promise<void> {
await this.sendCommands(["stop", "isready"], "readyok"); this.workerQueue = [];
await this.broadcastCommands(["stop", "isready"], "readyok");
for (const worker of this.workers) {
this.releaseWorker(worker);
}
} }
private async sendCommands( private async sendCommands(
@@ -150,8 +148,9 @@ export class UciEngine {
finalMessage: string, finalMessage: string,
onNewMessage?: (messages: string[]) => void onNewMessage?: (messages: string[]) => void
): Promise<string[]> { ): Promise<string[]> {
const acquired = this.acquireWorker(); const worker = this.acquireWorker();
if (!acquired) {
if (!worker) {
return new Promise((resolve) => { return new Promise((resolve) => {
this.workerQueue.push({ this.workerQueue.push({
commands, commands,
@@ -161,21 +160,13 @@ export class UciEngine {
}); });
}); });
} }
return new Promise((resolve) => {
const messages: string[] = [];
acquired.worker.listen = (data) => {
messages.push(data);
onNewMessage?.(messages);
if (data.startsWith(finalMessage)) { return this.sendCommandsToWorker(
this.releaseWorker(acquired.index); worker,
resolve(messages); commands,
} finalMessage,
}; onNewMessage
for (const command of commands) { );
acquired.worker.uci(command);
}
});
} }
private async sendCommandsToWorker( private async sendCommandsToWorker(
@@ -189,7 +180,9 @@ export class UciEngine {
worker.listen = (data) => { worker.listen = (data) => {
messages.push(data); messages.push(data);
onNewMessage?.(messages); onNewMessage?.(messages);
if (data.startsWith(finalMessage)) { if (data.startsWith(finalMessage)) {
this.releaseWorker(worker);
resolve(messages); resolve(messages);
} }
}; };
@@ -199,13 +192,15 @@ export class UciEngine {
}); });
} }
private broadcastCommands( private async broadcastCommands(
commands: string[], commands: string[],
finalMessage: string, finalMessage: string,
onNewMessage?: (messages: string[]) => void onNewMessage?: (messages: string[]) => void
): Promise<string[]>[] { ): Promise<void> {
return this.workers.map((worker) => await Promise.all(
this.sendCommandsToWorker(worker, commands, finalMessage, onNewMessage) this.workers.map((worker) =>
this.sendCommandsToWorker(worker, commands, finalMessage, onNewMessage)
)
); );
} }
@@ -219,17 +214,16 @@ export class UciEngine {
this.throwErrorIfNotReady(); this.throwErrorIfNotReady();
setEvaluationProgress?.(1); setEvaluationProgress?.(1);
await this.setMultiPv(multiPv); await this.setMultiPv(multiPv);
this.ready = false; this.isReady = false;
await this.sendCommands( await this.broadcastCommands(["ucinewgame", "isready"], "readyok");
["ucinewgame", "position startpos", "isready"],
"readyok"
);
const positions: PositionEval[] = new Array(fens.length); const positions: PositionEval[] = new Array(fens.length);
let completed = 0; let completed = 0;
const updateProgress = () => { const updateEval = (index: number, positionEval: PositionEval) => {
completed++;
positions[index] = positionEval;
const progress = completed / fens.length; const progress = completed / fens.length;
setEvaluationProgress?.(99 - Math.exp(-4 * progress) * 99); setEvaluationProgress?.(99 - Math.exp(-4 * progress) * 99);
}; };
@@ -238,7 +232,7 @@ export class UciEngine {
fens.map(async (fen, i) => { fens.map(async (fen, i) => {
const whoIsCheckmated = getWhoIsCheckmated(fen); const whoIsCheckmated = getWhoIsCheckmated(fen);
if (whoIsCheckmated) { if (whoIsCheckmated) {
positions[i] = { updateEval(i, {
lines: [ lines: [
{ {
pv: [], pv: [],
@@ -247,15 +241,13 @@ export class UciEngine {
mate: whoIsCheckmated === "w" ? -1 : 1, mate: whoIsCheckmated === "w" ? -1 : 1,
}, },
], ],
}; });
completed++;
updateProgress();
return; return;
} }
const isStalemate = getIsStalemate(fen); const isStalemate = getIsStalemate(fen);
if (isStalemate) { if (isStalemate) {
positions[i] = { updateEval(i, {
lines: [ lines: [
{ {
pv: [], pv: [],
@@ -264,16 +256,12 @@ export class UciEngine {
cp: 0, cp: 0,
}, },
], ],
}; });
completed++;
updateProgress();
return; return;
} }
const result = await this.evaluatePosition(fen, depth); const result = await this.evaluatePosition(fen, depth);
positions[i] = result; updateEval(i, result);
completed++;
updateProgress();
}) })
); );
@@ -284,7 +272,7 @@ export class UciEngine {
); );
const accuracy = computeAccuracy(positions); const accuracy = computeAccuracy(positions);
this.ready = true; this.isReady = true;
return { return {
positions: positionsWithClassification, positions: positionsWithClassification,
accuracy, accuracy,
@@ -301,14 +289,14 @@ export class UciEngine {
fen: string, fen: string,
depth = 16 depth = 16
): Promise<PositionEval> { ): Promise<PositionEval> {
console.log(`Evaluating position: ${fen}`); if (this.workers.length < 2) {
const lichessEval = await getLichessEval(fen, this.multiPv);
const lichessEval = await getLichessEval(fen, this.multiPv); if (
if ( lichessEval.lines.length >= this.multiPv &&
lichessEval.lines.length >= this.multiPv && lichessEval.lines[0].depth >= depth
lichessEval.lines[0].depth >= depth ) {
) { return lichessEval;
return lichessEval; }
} }
const results = await this.sendCommands( const results = await this.sendCommands(

View File

@@ -1,15 +1,25 @@
import { EngineWorker } from "@/types/engine"; import { EngineWorker } from "@/types/engine";
export const getEngineWorkers = (enginePath: string): EngineWorker[] => { export const getEngineWorkers = (
enginePath: string,
workersInputNb?: number
): EngineWorker[] => {
if (workersInputNb !== undefined && workersInputNb < 1) {
throw new Error(
`Number of workers must be greater than 0, got ${workersInputNb} instead`
);
}
const engineWorkers: EngineWorker[] = []; const engineWorkers: EngineWorker[] = [];
const instanceCount = const maxWorkersNb = Math.max(1, navigator.hardwareConcurrency - 4);
navigator.hardwareConcurrency - (navigator.hardwareConcurrency % 2 ? 0 : 1); const workersNb = workersInputNb ?? maxWorkersNb;
for (let i = 0; i < instanceCount; i++) { for (let i = 0; i < workersNb; i++) {
const worker = new Worker(enginePath); const worker = new Worker(enginePath);
const engineWorker: EngineWorker = { const engineWorker: EngineWorker = {
isReady: false,
uci: (command: string) => worker.postMessage(command), uci: (command: string) => worker.postMessage(command),
listen: () => null, listen: () => null,
terminate: () => worker.terminate(), terminate: () => worker.terminate(),

View File

@@ -17,7 +17,7 @@ import { getMovesClassification } from "@/lib/engine/helpers/moveClassification"
export const useCurrentPosition = (engineName?: EngineName) => { export const useCurrentPosition = (engineName?: EngineName) => {
const [currentPosition, setCurrentPosition] = useAtom(currentPositionAtom); const [currentPosition, setCurrentPosition] = useAtom(currentPositionAtom);
const engine = useEngine(engineName); const engine = useEngine(engineName, 1);
const gameEval = useAtomValue(gameEvalAtom); const gameEval = useAtomValue(gameEvalAtom);
const game = useAtomValue(gameAtom); const game = useAtomValue(gameAtom);
const board = useAtomValue(boardAtom); const board = useAtomValue(boardAtom);
@@ -52,7 +52,7 @@ export const useCurrentPosition = (engineName?: EngineName) => {
if ( if (
!position.eval && !position.eval &&
engine?.isReady() && engine?.getIsReady() &&
engineName && engineName &&
!board.isCheckmate() && !board.isCheckmate() &&
!board.isStalemate() !board.isStalemate()
@@ -61,7 +61,7 @@ export const useCurrentPosition = (engineName?: EngineName) => {
fen: string, fen: string,
setPartialEval?: (positionEval: PositionEval) => void setPartialEval?: (positionEval: PositionEval) => void
) => { ) => {
if (!engine?.isReady() || !engineName) if (!engine?.getIsReady() || !engineName)
throw new Error("Engine not ready"); throw new Error("Engine not ready");
const savedEval = savedEvals[fen]; const savedEval = savedEvals[fen];
if ( if (

View File

@@ -30,11 +30,15 @@ export default function AnalyzeButton() {
const setSavedEvals = useSetAtom(savedEvalsAtom); const setSavedEvals = useSetAtom(savedEvalsAtom);
const readyToAnalyse = const readyToAnalyse =
engine?.isReady() && game.history().length > 0 && !evaluationProgress; engine?.getIsReady() && game.history().length > 0 && !evaluationProgress;
const handleAnalyze = async () => { const handleAnalyze = async () => {
const params = getEvaluateGameParams(game); const params = getEvaluateGameParams(game);
if (!engine?.isReady() || params.fens.length === 0 || evaluationProgress) { if (
!engine?.getIsReady() ||
params.fens.length === 0 ||
evaluationProgress
) {
return; return;
} }

View File

@@ -13,7 +13,7 @@ export const showBestMoveArrowAtom = atom(true);
export const showPlayerMoveIconAtom = atom(true); export const showPlayerMoveIconAtom = atom(true);
export const engineNameAtom = atom<EngineName>(EngineName.Stockfish17Lite); export const engineNameAtom = atom<EngineName>(EngineName.Stockfish17Lite);
export const engineDepthAtom = atom(16); export const engineDepthAtom = atom(14);
export const engineMultiPvAtom = atom(3); export const engineMultiPvAtom = atom(3);
export const evaluationProgressAtom = atom(0); export const evaluationProgressAtom = atom(0);

View File

@@ -19,7 +19,7 @@ import { useGameData } from "@/hooks/useGameData";
export default function BoardContainer() { export default function BoardContainer() {
const screenSize = useScreenSize(); const screenSize = useScreenSize();
const engineName = useAtomValue(enginePlayNameAtom); const engineName = useAtomValue(enginePlayNameAtom);
const engine = useEngine(engineName); const engine = useEngine(engineName, 1);
const game = useAtomValue(gameAtom); const game = useAtomValue(gameAtom);
const playerColor = useAtomValue(playerColorAtom); const playerColor = useAtomValue(playerColorAtom);
const { makeMove: makeGameMove } = useChessActions(gameAtom); const { makeMove: makeGameMove } = useChessActions(gameAtom);
@@ -32,7 +32,7 @@ export default function BoardContainer() {
useEffect(() => { useEffect(() => {
const playEngineMove = async () => { const playEngineMove = async () => {
if ( if (
!engine?.isReady() || !engine?.getIsReady() ||
game.turn() === playerColor || game.turn() === playerColor ||
isGameFinished || isGameFinished ||
!isGameInProgress !isGameInProgress

View File

@@ -1,6 +1,14 @@
export interface EngineWorker { export interface EngineWorker {
isReady: boolean;
uci(command: string): void; uci(command: string): void;
listen: (data: string) => void; listen: (data: string) => void;
terminate?: () => void; terminate?: () => void;
setNnueBuffer?: (data: Uint8Array, index?: number) => void; setNnueBuffer?: (data: Uint8Array, index?: number) => void;
} }
export interface WorkerJob {
commands: string[];
finalMessage: string;
onNewMessage?: (messages: string[]) => void;
resolve: (messages: string[]) => void;
}