feat: multiple parallel workers (#8)

This commit is contained in:
supertorpe
2025-04-20 01:27:42 +02:00
committed by GitHub
parent bf9769728e
commit 61c90b9c6b
6 changed files with 179 additions and 74 deletions

View File

@@ -1,12 +1,12 @@
import { EngineName } from "@/types/enums";
import { UciEngine } from "./uciEngine";
import { getEngineWorker } from "./worker";
import { getEngineWorkers } from "./worker";
export class Stockfish11 {
public static async create(): Promise<UciEngine> {
const worker = getEngineWorker("engines/stockfish-11.js");
const workers = getEngineWorkers("engines/stockfish-11.js");
return UciEngine.create(EngineName.Stockfish11, worker);
return UciEngine.create(EngineName.Stockfish11, workers);
}
public static isSupported() {

View File

@@ -1,7 +1,7 @@
import { EngineName } from "@/types/enums";
import { UciEngine } from "./uciEngine";
import { isMultiThreadSupported, isWasmSupported } from "./shared";
import { getEngineWorker } from "./worker";
import { getEngineWorkers } from "./worker";
export class Stockfish16 {
public static async create(nnue?: boolean): Promise<UciEngine> {
@@ -25,9 +25,9 @@ export class Stockfish16 {
);
};
const worker = getEngineWorker(enginePath);
const workers = getEngineWorkers(enginePath);
return UciEngine.create(EngineName.Stockfish16, worker, customEngineInit);
return UciEngine.create(EngineName.Stockfish16, workers, customEngineInit);
}
public static isSupported() {

View File

@@ -1,7 +1,7 @@
import { EngineName } from "@/types/enums";
import { UciEngine } from "./uciEngine";
import { isMultiThreadSupported, isWasmSupported } from "./shared";
import { getEngineWorker } from "./worker";
import { getEngineWorkers } from "./worker";
export class Stockfish16_1 {
public static async create(lite?: boolean): Promise<UciEngine> {
@@ -20,9 +20,9 @@ export class Stockfish16_1 {
? EngineName.Stockfish16_1Lite
: EngineName.Stockfish16_1;
const worker = getEngineWorker(enginePath);
const workers = getEngineWorkers(enginePath);
return UciEngine.create(engineName, worker);
return UciEngine.create(engineName, workers);
}
public static isSupported() {

View File

@@ -1,7 +1,7 @@
import { EngineName } from "@/types/enums";
import { UciEngine } from "./uciEngine";
import { isMultiThreadSupported, isWasmSupported } from "./shared";
import { getEngineWorker } from "./worker";
import { getEngineWorkers } from "./worker";
export class Stockfish17 {
public static async create(lite?: boolean): Promise<UciEngine> {
@@ -20,9 +20,9 @@ export class Stockfish17 {
? EngineName.Stockfish17Lite
: EngineName.Stockfish17;
const worker = getEngineWorker(enginePath);
const workers = getEngineWorkers(enginePath);
return UciEngine.create(engineName, worker);
return UciEngine.create(engineName, workers);
}
public static isSupported() {

View File

@@ -15,28 +15,38 @@ import { getLichessEval } from "../lichess";
import { getMovesClassification } from "./helpers/moveClassification";
import { EngineWorker } from "@/types/engine";
type WorkerJob = {
commands: string[];
finalMessage: string;
onNewMessage?: (messages: string[]) => void;
resolve: (messages: string[]) => void;
};
export class UciEngine {
private worker: EngineWorker;
private workers: EngineWorker[];
private isBusy: boolean[] = [];
private workerQueue: WorkerJob[] = [];
private ready = false;
private engineName: EngineName;
private multiPv = 3;
private skillLevel: number | undefined = undefined;
private constructor(engineName: EngineName, worker: EngineWorker) {
private constructor(engineName: EngineName, workers: EngineWorker[]) {
this.engineName = engineName;
this.worker = worker;
this.workers = workers;
this.isBusy = new Array(workers.length).fill(false);
}
public static async create(
engineName: EngineName,
worker: EngineWorker,
workers: EngineWorker[],
customEngineInit?: (
sendCommands: UciEngine["sendCommands"]
) => Promise<void>
): Promise<UciEngine> {
const engine = new UciEngine(engineName, worker);
const engine = new UciEngine(engineName, workers);
await engine.sendCommands(["uci"], "uciok");
await engine.broadcastCommands(["uci"], "uciok");
await engine.setMultiPv(engine.multiPv, true);
await customEngineInit?.(engine.sendCommands.bind(engine));
engine.ready = true;
@@ -45,6 +55,30 @@ export class UciEngine {
return engine;
}
private acquireWorker(): { index: number; worker: EngineWorker } | undefined {
for (let i = 0; i < this.workers.length; i++) {
if (!this.isBusy[i]) {
this.isBusy[i] = true;
return { index: i, worker: this.workers[i] };
}
}
return undefined;
}
private releaseWorker(index: number) {
this.isBusy[index] = false;
if (this.workerQueue.length > 0) {
const nextJob = this.workerQueue.shift()!;
this.sendCommands(
nextJob.commands,
nextJob.finalMessage,
nextJob.onNewMessage
).then(nextJob.resolve);
}
}
private async setMultiPv(multiPv: number, initCase = false) {
if (!initCase) {
if (multiPv === this.multiPv) return;
@@ -56,7 +90,7 @@ export class UciEngine {
throw new Error(`Invalid MultiPV value : ${multiPv}`);
}
await this.sendCommands(
await this.broadcastCommands(
[`setoption name MultiPV value ${multiPv}`, "isready"],
"readyok"
);
@@ -91,8 +125,15 @@ export class UciEngine {
public shutdown(): void {
this.ready = false;
this.worker.uci("quit");
this.worker.terminate?.();
for (const worker of this.workers) {
worker.uci("quit");
worker.terminate?.();
}
this.isBusy = Array(this.workers.length).fill(false);
this.workerQueue = [];
console.log(`${this.engineName} shutdown`);
}
@@ -109,24 +150,65 @@ export class UciEngine {
finalMessage: string,
onNewMessage?: (messages: string[]) => void
): Promise<string[]> {
const acquired = this.acquireWorker();
if (!acquired) {
return new Promise((resolve) => {
this.workerQueue.push({
commands,
finalMessage,
onNewMessage,
resolve,
});
});
}
return new Promise((resolve) => {
const messages: string[] = [];
this.worker.listen = (data) => {
acquired.worker.listen = (data) => {
messages.push(data);
onNewMessage?.(messages);
if (data.startsWith(finalMessage)) {
this.releaseWorker(acquired.index);
resolve(messages);
}
};
for (const command of commands) {
this.worker.uci(command);
acquired.worker.uci(command);
}
});
}
private async sendCommandsToWorker(
worker: EngineWorker,
commands: string[],
finalMessage: string,
onNewMessage?: (messages: string[]) => void
): Promise<string[]> {
return new Promise((resolve) => {
const messages: string[] = [];
worker.listen = (data) => {
messages.push(data);
onNewMessage?.(messages);
if (data.startsWith(finalMessage)) {
resolve(messages);
}
};
for (const command of commands) {
worker.uci(command);
}
});
}
private broadcastCommands(
commands: string[],
finalMessage: string,
onNewMessage?: (messages: string[]) => void
): Promise<string[]>[] {
return this.workers.map((worker) =>
this.sendCommandsToWorker(worker, commands, finalMessage, onNewMessage)
);
}
public async evaluateGame({
fens,
uciMoves,
@@ -139,14 +221,24 @@ export class UciEngine {
await this.setMultiPv(multiPv);
this.ready = false;
await this.sendCommands(["ucinewgame", "isready"], "readyok");
this.worker.uci("position startpos");
await this.sendCommands(
["ucinewgame", "position startpos", "isready"],
"readyok"
);
const positions: PositionEval[] = [];
for (const fen of fens) {
const positions: PositionEval[] = new Array(fens.length);
let completed = 0;
const updateProgress = () => {
const progress = completed / fens.length;
setEvaluationProgress?.(99 - Math.exp(-4 * progress) * 99);
};
await Promise.all(
fens.map(async (fen, i) => {
const whoIsCheckmated = getWhoIsCheckmated(fen);
if (whoIsCheckmated) {
positions.push({
positions[i] = {
lines: [
{
pv: [],
@@ -155,13 +247,15 @@ export class UciEngine {
mate: whoIsCheckmated === "w" ? -1 : 1,
},
],
});
continue;
};
completed++;
updateProgress();
return;
}
const isStalemate = getIsStalemate(fen);
if (isStalemate) {
positions.push({
positions[i] = {
lines: [
{
pv: [],
@@ -170,16 +264,18 @@ export class UciEngine {
cp: 0,
},
],
});
continue;
};
completed++;
updateProgress();
return;
}
const result = await this.evaluatePosition(fen, depth);
positions.push(result);
setEvaluationProgress?.(
99 - Math.exp(-4 * (fens.indexOf(fen) / fens.length)) * 99
positions[i] = result;
completed++;
updateProgress();
})
);
}
const positionsWithClassification = getMovesClassification(
positions,

View File

@@ -1,6 +1,12 @@
import { EngineWorker } from "@/types/engine";
export const getEngineWorker = (enginePath: string): EngineWorker => {
export const getEngineWorkers = (enginePath: string): EngineWorker[] => {
const engineWorkers: EngineWorker[] = [];
const instanceCount =
navigator.hardwareConcurrency - (navigator.hardwareConcurrency % 2 ? 0 : 1);
for (let i = 0; i < instanceCount; i++) {
const worker = new Worker(enginePath);
const engineWorker: EngineWorker = {
@@ -13,5 +19,8 @@ export const getEngineWorker = (enginePath: string): EngineWorker => {
engineWorker.listen(event.data);
};
return engineWorker;
engineWorkers.push(engineWorker);
}
return engineWorkers;
};