feat: multiple parallel workers (#8)

This commit is contained in:
supertorpe
2025-04-20 01:27:42 +02:00
committed by GitHub
parent bf9769728e
commit 61c90b9c6b
6 changed files with 179 additions and 74 deletions

View File

@@ -1,12 +1,12 @@
import { EngineName } from "@/types/enums"; import { EngineName } from "@/types/enums";
import { UciEngine } from "./uciEngine"; import { UciEngine } from "./uciEngine";
import { getEngineWorker } from "./worker"; import { getEngineWorkers } from "./worker";
export class Stockfish11 { export class Stockfish11 {
public static async create(): Promise<UciEngine> { public static async create(): Promise<UciEngine> {
const worker = getEngineWorker("engines/stockfish-11.js"); const workers = getEngineWorkers("engines/stockfish-11.js");
return UciEngine.create(EngineName.Stockfish11, worker); return UciEngine.create(EngineName.Stockfish11, workers);
} }
public static isSupported() { public static isSupported() {

View File

@@ -1,7 +1,7 @@
import { EngineName } from "@/types/enums"; import { EngineName } from "@/types/enums";
import { UciEngine } from "./uciEngine"; import { UciEngine } from "./uciEngine";
import { isMultiThreadSupported, isWasmSupported } from "./shared"; import { isMultiThreadSupported, isWasmSupported } from "./shared";
import { getEngineWorker } from "./worker"; import { getEngineWorkers } from "./worker";
export class Stockfish16 { export class Stockfish16 {
public static async create(nnue?: boolean): Promise<UciEngine> { public static async create(nnue?: boolean): Promise<UciEngine> {
@@ -25,9 +25,9 @@ export class Stockfish16 {
); );
}; };
const worker = getEngineWorker(enginePath); const workers = getEngineWorkers(enginePath);
return UciEngine.create(EngineName.Stockfish16, worker, customEngineInit); return UciEngine.create(EngineName.Stockfish16, workers, customEngineInit);
} }
public static isSupported() { public static isSupported() {

View File

@@ -1,7 +1,7 @@
import { EngineName } from "@/types/enums"; import { EngineName } from "@/types/enums";
import { UciEngine } from "./uciEngine"; import { UciEngine } from "./uciEngine";
import { isMultiThreadSupported, isWasmSupported } from "./shared"; import { isMultiThreadSupported, isWasmSupported } from "./shared";
import { getEngineWorker } from "./worker"; import { getEngineWorkers } from "./worker";
export class Stockfish16_1 { export class Stockfish16_1 {
public static async create(lite?: boolean): Promise<UciEngine> { public static async create(lite?: boolean): Promise<UciEngine> {
@@ -20,9 +20,9 @@ export class Stockfish16_1 {
? EngineName.Stockfish16_1Lite ? EngineName.Stockfish16_1Lite
: EngineName.Stockfish16_1; : EngineName.Stockfish16_1;
const worker = getEngineWorker(enginePath); const workers = getEngineWorkers(enginePath);
return UciEngine.create(engineName, worker); return UciEngine.create(engineName, workers);
} }
public static isSupported() { public static isSupported() {

View File

@@ -1,7 +1,7 @@
import { EngineName } from "@/types/enums"; import { EngineName } from "@/types/enums";
import { UciEngine } from "./uciEngine"; import { UciEngine } from "./uciEngine";
import { isMultiThreadSupported, isWasmSupported } from "./shared"; import { isMultiThreadSupported, isWasmSupported } from "./shared";
import { getEngineWorker } from "./worker"; import { getEngineWorkers } from "./worker";
export class Stockfish17 { export class Stockfish17 {
public static async create(lite?: boolean): Promise<UciEngine> { public static async create(lite?: boolean): Promise<UciEngine> {
@@ -20,9 +20,9 @@ export class Stockfish17 {
? EngineName.Stockfish17Lite ? EngineName.Stockfish17Lite
: EngineName.Stockfish17; : EngineName.Stockfish17;
const worker = getEngineWorker(enginePath); const workers = getEngineWorkers(enginePath);
return UciEngine.create(engineName, worker); return UciEngine.create(engineName, workers);
} }
public static isSupported() { public static isSupported() {

View File

@@ -15,28 +15,38 @@ import { getLichessEval } from "../lichess";
import { getMovesClassification } from "./helpers/moveClassification"; import { getMovesClassification } from "./helpers/moveClassification";
import { EngineWorker } from "@/types/engine"; import { EngineWorker } from "@/types/engine";
type WorkerJob = {
commands: string[];
finalMessage: string;
onNewMessage?: (messages: string[]) => void;
resolve: (messages: string[]) => void;
};
export class UciEngine { export class UciEngine {
private worker: EngineWorker; private workers: EngineWorker[];
private isBusy: boolean[] = [];
private workerQueue: WorkerJob[] = [];
private ready = false; private ready = false;
private engineName: EngineName; private engineName: EngineName;
private multiPv = 3; private multiPv = 3;
private skillLevel: number | undefined = undefined; private skillLevel: number | undefined = undefined;
private constructor(engineName: EngineName, worker: EngineWorker) { private constructor(engineName: EngineName, workers: EngineWorker[]) {
this.engineName = engineName; this.engineName = engineName;
this.worker = worker; this.workers = workers;
this.isBusy = new Array(workers.length).fill(false);
} }
public static async create( public static async create(
engineName: EngineName, engineName: EngineName,
worker: EngineWorker, workers: EngineWorker[],
customEngineInit?: ( customEngineInit?: (
sendCommands: UciEngine["sendCommands"] sendCommands: UciEngine["sendCommands"]
) => Promise<void> ) => Promise<void>
): Promise<UciEngine> { ): Promise<UciEngine> {
const engine = new UciEngine(engineName, worker); const engine = new UciEngine(engineName, workers);
await engine.sendCommands(["uci"], "uciok"); await engine.broadcastCommands(["uci"], "uciok");
await engine.setMultiPv(engine.multiPv, true); await engine.setMultiPv(engine.multiPv, true);
await customEngineInit?.(engine.sendCommands.bind(engine)); await customEngineInit?.(engine.sendCommands.bind(engine));
engine.ready = true; engine.ready = true;
@@ -45,6 +55,30 @@ export class UciEngine {
return engine; return engine;
} }
private acquireWorker(): { index: number; worker: EngineWorker } | undefined {
for (let i = 0; i < this.workers.length; i++) {
if (!this.isBusy[i]) {
this.isBusy[i] = true;
return { index: i, worker: this.workers[i] };
}
}
return undefined;
}
private releaseWorker(index: number) {
this.isBusy[index] = false;
if (this.workerQueue.length > 0) {
const nextJob = this.workerQueue.shift()!;
this.sendCommands(
nextJob.commands,
nextJob.finalMessage,
nextJob.onNewMessage
).then(nextJob.resolve);
}
}
private async setMultiPv(multiPv: number, initCase = false) { private async setMultiPv(multiPv: number, initCase = false) {
if (!initCase) { if (!initCase) {
if (multiPv === this.multiPv) return; if (multiPv === this.multiPv) return;
@@ -56,7 +90,7 @@ export class UciEngine {
throw new Error(`Invalid MultiPV value : ${multiPv}`); throw new Error(`Invalid MultiPV value : ${multiPv}`);
} }
await this.sendCommands( await this.broadcastCommands(
[`setoption name MultiPV value ${multiPv}`, "isready"], [`setoption name MultiPV value ${multiPv}`, "isready"],
"readyok" "readyok"
); );
@@ -91,8 +125,15 @@ export class UciEngine {
public shutdown(): void { public shutdown(): void {
this.ready = false; this.ready = false;
this.worker.uci("quit");
this.worker.terminate?.(); for (const worker of this.workers) {
worker.uci("quit");
worker.terminate?.();
}
this.isBusy = Array(this.workers.length).fill(false);
this.workerQueue = [];
console.log(`${this.engineName} shutdown`); console.log(`${this.engineName} shutdown`);
} }
@@ -109,24 +150,65 @@ export class UciEngine {
finalMessage: string, finalMessage: string,
onNewMessage?: (messages: string[]) => void onNewMessage?: (messages: string[]) => void
): Promise<string[]> { ): Promise<string[]> {
const acquired = this.acquireWorker();
if (!acquired) {
return new Promise((resolve) => {
this.workerQueue.push({
commands,
finalMessage,
onNewMessage,
resolve,
});
});
}
return new Promise((resolve) => { return new Promise((resolve) => {
const messages: string[] = []; const messages: string[] = [];
acquired.worker.listen = (data) => {
this.worker.listen = (data) => {
messages.push(data); messages.push(data);
onNewMessage?.(messages); onNewMessage?.(messages);
if (data.startsWith(finalMessage)) { if (data.startsWith(finalMessage)) {
this.releaseWorker(acquired.index);
resolve(messages); resolve(messages);
} }
}; };
for (const command of commands) { for (const command of commands) {
this.worker.uci(command); acquired.worker.uci(command);
} }
}); });
} }
private async sendCommandsToWorker(
worker: EngineWorker,
commands: string[],
finalMessage: string,
onNewMessage?: (messages: string[]) => void
): Promise<string[]> {
return new Promise((resolve) => {
const messages: string[] = [];
worker.listen = (data) => {
messages.push(data);
onNewMessage?.(messages);
if (data.startsWith(finalMessage)) {
resolve(messages);
}
};
for (const command of commands) {
worker.uci(command);
}
});
}
private broadcastCommands(
commands: string[],
finalMessage: string,
onNewMessage?: (messages: string[]) => void
): Promise<string[]>[] {
return this.workers.map((worker) =>
this.sendCommandsToWorker(worker, commands, finalMessage, onNewMessage)
);
}
public async evaluateGame({ public async evaluateGame({
fens, fens,
uciMoves, uciMoves,
@@ -139,14 +221,24 @@ export class UciEngine {
await this.setMultiPv(multiPv); await this.setMultiPv(multiPv);
this.ready = false; this.ready = false;
await this.sendCommands(["ucinewgame", "isready"], "readyok"); await this.sendCommands(
this.worker.uci("position startpos"); ["ucinewgame", "position startpos", "isready"],
"readyok"
);
const positions: PositionEval[] = []; const positions: PositionEval[] = new Array(fens.length);
for (const fen of fens) { let completed = 0;
const updateProgress = () => {
const progress = completed / fens.length;
setEvaluationProgress?.(99 - Math.exp(-4 * progress) * 99);
};
await Promise.all(
fens.map(async (fen, i) => {
const whoIsCheckmated = getWhoIsCheckmated(fen); const whoIsCheckmated = getWhoIsCheckmated(fen);
if (whoIsCheckmated) { if (whoIsCheckmated) {
positions.push({ positions[i] = {
lines: [ lines: [
{ {
pv: [], pv: [],
@@ -155,13 +247,15 @@ export class UciEngine {
mate: whoIsCheckmated === "w" ? -1 : 1, mate: whoIsCheckmated === "w" ? -1 : 1,
}, },
], ],
}); };
continue; completed++;
updateProgress();
return;
} }
const isStalemate = getIsStalemate(fen); const isStalemate = getIsStalemate(fen);
if (isStalemate) { if (isStalemate) {
positions.push({ positions[i] = {
lines: [ lines: [
{ {
pv: [], pv: [],
@@ -170,16 +264,18 @@ export class UciEngine {
cp: 0, cp: 0,
}, },
], ],
}); };
continue; completed++;
updateProgress();
return;
} }
const result = await this.evaluatePosition(fen, depth); const result = await this.evaluatePosition(fen, depth);
positions.push(result); positions[i] = result;
setEvaluationProgress?.( completed++;
99 - Math.exp(-4 * (fens.indexOf(fen) / fens.length)) * 99 updateProgress();
})
); );
}
const positionsWithClassification = getMovesClassification( const positionsWithClassification = getMovesClassification(
positions, positions,

View File

@@ -1,6 +1,12 @@
import { EngineWorker } from "@/types/engine"; import { EngineWorker } from "@/types/engine";
export const getEngineWorker = (enginePath: string): EngineWorker => { export const getEngineWorkers = (enginePath: string): EngineWorker[] => {
const engineWorkers: EngineWorker[] = [];
const instanceCount =
navigator.hardwareConcurrency - (navigator.hardwareConcurrency % 2 ? 0 : 1);
for (let i = 0; i < instanceCount; i++) {
const worker = new Worker(enginePath); const worker = new Worker(enginePath);
const engineWorker: EngineWorker = { const engineWorker: EngineWorker = {
@@ -13,5 +19,8 @@ export const getEngineWorker = (enginePath: string): EngineWorker => {
engineWorker.listen(event.data); engineWorker.listen(event.data);
}; };
return engineWorker; engineWorkers.push(engineWorker);
}
return engineWorkers;
}; };