feat : add live move classification

This commit is contained in:
GuillaumeSD
2024-04-14 19:12:59 +02:00
parent 6b774e085a
commit 5379893288
5 changed files with 118 additions and 18 deletions

View File

@@ -213,7 +213,7 @@ export abstract class UciEngine {
depth = 16, depth = 16,
multiPv = this.multiPv, multiPv = this.multiPv,
setPartialEval, setPartialEval,
}: EvaluatePositionWithUpdateParams): Promise<void> { }: EvaluatePositionWithUpdateParams): Promise<PositionEval> {
this.throwErrorIfNotReady(); this.throwErrorIfNotReady();
const lichessEvalPromise = getLichessEval(fen, multiPv); const lichessEvalPromise = getLichessEval(fen, multiPv);
@@ -224,6 +224,7 @@ export abstract class UciEngine {
const whiteToPlay = fen.split(" ")[1] === "w"; const whiteToPlay = fen.split(" ")[1] === "w";
const onNewMessage = (messages: string[]) => { const onNewMessage = (messages: string[]) => {
if (!setPartialEval) return;
const parsedResults = parseEvaluationResults(messages, whiteToPlay); const parsedResults = parseEvaluationResults(messages, whiteToPlay);
setPartialEval(parsedResults); setPartialEval(parsedResults);
}; };
@@ -235,15 +236,17 @@ export abstract class UciEngine {
lichessEval.lines.length >= multiPv && lichessEval.lines.length >= multiPv &&
lichessEval.lines[0].depth >= depth lichessEval.lines[0].depth >= depth
) { ) {
setPartialEval(lichessEval); setPartialEval?.(lichessEval);
return; return lichessEval;
} }
await this.sendCommands( const results = await this.sendCommands(
[`position fen ${fen}`, `go depth ${depth}`], [`position fen ${fen}`, `go depth ${depth}`],
"bestmove", "bestmove",
onNewMessage onNewMessage
); );
return parseEvaluationResults(results, whiteToPlay);
} }
public async getEngineNextMove( public async getEngineNextMove(

View File

@@ -5,12 +5,15 @@ import {
engineMultiPvAtom, engineMultiPvAtom,
gameAtom, gameAtom,
gameEvalAtom, gameEvalAtom,
savedEvalsAtom,
} from "@/sections/analysis/states"; } from "@/sections/analysis/states";
import { CurrentPosition, PositionEval } from "@/types/eval"; import { CurrentPosition, PositionEval } from "@/types/eval";
import { useAtom, useAtomValue } from "jotai"; import { useAtom, useAtomValue } from "jotai";
import { useEffect } from "react"; import { useEffect } from "react";
import { useEngine } from "../../../hooks/useEngine"; import { useEngine } from "../../../hooks/useEngine";
import { EngineName } from "@/types/enums"; import { EngineName } from "@/types/enums";
import { getEvaluateGameParams } from "@/lib/chess";
import { getMovesClassification } from "@/lib/engine/helpers/moveClassification";
export const useCurrentPosition = (engineName?: EngineName) => { export const useCurrentPosition = (engineName?: EngineName) => {
const [currentPosition, setCurrentPosition] = useAtom(currentPositionAtom); const [currentPosition, setCurrentPosition] = useAtom(currentPositionAtom);
@@ -20,6 +23,7 @@ export const useCurrentPosition = (engineName?: EngineName) => {
const board = useAtomValue(boardAtom); const board = useAtomValue(boardAtom);
const depth = useAtomValue(engineDepthAtom); const depth = useAtomValue(engineDepthAtom);
const multiPv = useAtomValue(engineMultiPvAtom); const multiPv = useAtomValue(engineMultiPvAtom);
const [savedEvals, setSavedEvals] = useAtom(savedEvalsAtom);
useEffect(() => { useEffect(() => {
const position: CurrentPosition = { const position: CurrentPosition = {
@@ -44,21 +48,92 @@ export const useCurrentPosition = (engineName?: EngineName) => {
} }
} }
if (!position.eval && engine?.isReady()) { setCurrentPosition(position);
const setPartialEval = (positionEval: PositionEval) => {
setCurrentPosition({ ...position, eval: positionEval }); if (!position.eval && engine?.isReady() && engineName) {
const getFenEngineEval = async (
fen: string,
setPartialEval?: (positionEval: PositionEval) => void
) => {
if (!engine?.isReady() || !engineName)
throw new Error("Engine not ready");
const savedEval = savedEvals[fen];
if (
savedEval &&
savedEval.engine === engineName &&
savedEval.lines[0].depth >= depth
) {
setPartialEval?.(savedEval);
return savedEval;
}
const rawPositionEval = await engine.evaluatePositionWithUpdate({
fen,
depth,
multiPv,
setPartialEval,
});
setSavedEvals((prev) => ({
...prev,
[fen]: { ...rawPositionEval, engine: engineName },
}));
return rawPositionEval;
}; };
engine.evaluatePositionWithUpdate({ const getPositionEval = async () => {
fen: board.fen(), const setPartialEval = (positionEval: PositionEval) => {
depth, setCurrentPosition({ ...position, eval: positionEval });
multiPv, };
setPartialEval, const rawPositionEval = await getFenEngineEval(
}); board.fen(),
setPartialEval
);
if (boardHistory.length === 0) return;
const params = getEvaluateGameParams(board);
const fens = params.fens.slice(board.turn() === "w" ? -3 : -4);
const uciMoves = params.uciMoves.slice(board.turn() === "w" ? -3 : -4);
const lastRawEval = await getFenEngineEval(fens.slice(-2)[0]);
const rawPositions: PositionEval[] = fens.map((_, idx) => {
if (idx === fens.length - 2) return lastRawEval;
if (idx === fens.length - 1) return rawPositionEval;
return {
lines: [
{
pv: [],
depth: 0,
multiPv: 1,
cp: 1,
},
],
};
});
const positionsWithMoveClassification = getMovesClassification(
rawPositions,
uciMoves,
fens
);
setCurrentPosition({
...position,
eval: positionsWithMoveClassification.slice(-1)[0],
lastEval: positionsWithMoveClassification.slice(-2)[0],
});
};
getPositionEval();
} }
setCurrentPosition(position); return () => {
}, [gameEval, board, game, engine, depth, multiPv, setCurrentPosition]); engine?.stopSearch();
};
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [gameEval, board, game, engine, depth, multiPv]);
return currentPosition; return currentPosition;
}; };

View File

@@ -6,13 +6,15 @@ import {
evaluationProgressAtom, evaluationProgressAtom,
gameAtom, gameAtom,
gameEvalAtom, gameEvalAtom,
savedEvalsAtom,
} from "../states"; } from "../states";
import { useAtom, useAtomValue } from "jotai"; import { useAtom, useAtomValue, useSetAtom } from "jotai";
import { getEvaluateGameParams } from "@/lib/chess"; import { getEvaluateGameParams } from "@/lib/chess";
import { useGameDatabase } from "@/hooks/useGameDatabase"; import { useGameDatabase } from "@/hooks/useGameDatabase";
import { LoadingButton } from "@mui/lab"; import { LoadingButton } from "@mui/lab";
import { useEngine } from "@/hooks/useEngine"; import { useEngine } from "@/hooks/useEngine";
import { logAnalyticsEvent } from "@/lib/firebase"; import { logAnalyticsEvent } from "@/lib/firebase";
import { SavedEvals } from "@/types/eval";
export default function AnalyzeButton() { export default function AnalyzeButton() {
const engineName = useAtomValue(engineNameAtom); const engineName = useAtomValue(engineNameAtom);
@@ -25,6 +27,7 @@ export default function AnalyzeButton() {
const { setGameEval, gameFromUrl } = useGameDatabase(); const { setGameEval, gameFromUrl } = useGameDatabase();
const [gameEval, setEval] = useAtom(gameEvalAtom); const [gameEval, setEval] = useAtom(gameEvalAtom);
const game = useAtomValue(gameAtom); const game = useAtomValue(gameAtom);
const setSavedEvals = useSetAtom(savedEvalsAtom);
const readyToAnalyse = const readyToAnalyse =
engine?.isReady() && game.history().length > 0 && !evaluationProgress; engine?.isReady() && game.history().length > 0 && !evaluationProgress;
@@ -49,6 +52,15 @@ export default function AnalyzeButton() {
setGameEval(gameFromUrl.id, newGameEval); setGameEval(gameFromUrl.id, newGameEval);
} }
const gameSavedEvals: SavedEvals = params.fens.reduce((acc, fen, idx) => {
acc[fen] = { ...newGameEval.positions[idx], engine: engineName };
return acc;
}, {} as SavedEvals);
setSavedEvals((prev) => ({
...prev,
...gameSavedEvals,
}));
logAnalyticsEvent("analyze_game", { logAnalyticsEvent("analyze_game", {
engine: engineName, engine: engineName,
depth: engineDepth, depth: engineDepth,

View File

@@ -1,5 +1,5 @@
import { EngineName } from "@/types/enums"; import { EngineName } from "@/types/enums";
import { CurrentPosition, GameEval } from "@/types/eval"; import { CurrentPosition, GameEval, SavedEvals } from "@/types/eval";
import { Chess } from "chess.js"; import { Chess } from "chess.js";
import { atom } from "jotai"; import { atom } from "jotai";
@@ -16,3 +16,5 @@ export const engineNameAtom = atom<EngineName>(EngineName.Stockfish16);
export const engineDepthAtom = atom(16); export const engineDepthAtom = atom(16);
export const engineMultiPvAtom = atom(3); export const engineMultiPvAtom = atom(3);
export const evaluationProgressAtom = atom(0); export const evaluationProgressAtom = atom(0);
export const savedEvalsAtom = atom<SavedEvals>({});

View File

@@ -38,7 +38,7 @@ export interface EvaluatePositionWithUpdateParams {
fen: string; fen: string;
depth?: number; depth?: number;
multiPv?: number; multiPv?: number;
setPartialEval: (positionEval: PositionEval) => void; setPartialEval?: (positionEval: PositionEval) => void;
} }
export interface CurrentPosition { export interface CurrentPosition {
@@ -55,3 +55,11 @@ export interface EvaluateGameParams {
multiPv?: number; multiPv?: number;
setEvaluationProgress?: (value: number) => void; setEvaluationProgress?: (value: number) => void;
} }
export interface SavedEval {
bestMove?: string;
lines: LineEval[];
engine: EngineName;
}
export type SavedEvals = Record<string, SavedEval | undefined>;