This is the one-file "toy" transformer from https://github.com/mukel/llama3.java in an enhanced form:
/**
 * Llama.java
 *
 * This file was extracted from the llama3/qwen2 projects:
 * https://github.com/mukel/llama3.java
 * https://github.com/mukel/qwen2.svm.java
 *
 * License: MIT License
 *
 * Copyright (c) 2024 Andrej Karpathy (for llama2.c)
 * Copyright (c) 2024 Alfonso² Peterssen (for llama3/qwen2)
 * Copyright (c) 2023 Georgi Gerganov et al. (for llama.cpp)
 * Copyright (c) 2025 Michael Peter Christen for modifications:
 *
 * The code was modified to fit the YaCy AI project:
 * - back-port to Java 11 (removal of Vector API operations and record types)
 * - removal of interactive mode and System.out printing
 * - separation of the classes in the single Java file into their own files, with refactoring
 * - run-time performance optimizations for the dot product computation of quantized values
 * - merging of llama3/qwen2 into one code base; multi-arch options
 * - alignment with code from https://github.com/ggml-org/llama.cpp/
 */

package net.yacy.ai.llama3;

import java.nio.FloatBuffer;
import java.util.*;
import java.util.function.IntConsumer;
import java.util.stream.IntStream;
import java.util.stream.Stream;

import net.yacy.ai.llama3.Model.Arch;
import net.yacy.ai.llama3.Model.Tokenizer;
import net.yacy.ai.llama3.Tensor.ArrayFloatTensor;
import net.yacy.ai.llama3.Tensor.FloatTensor;

public final class Llama {

    private final Configuration configuration;
    private final Tokenizer tokenizer;
    private final Weights weights;

    public Llama(Configuration configuration, Tokenizer tokenizer, Weights weights) {
        this.configuration = configuration;
        this.tokenizer = tokenizer;
        this.weights = weights;
    }

    public Configuration configuration() {
        return configuration;
    }

    public Tokenizer tokenizer() {
        return tokenizer;
    }

    public Weights weights() {
        return weights;
    }

    public State createNewState(int batchsize) {
        State state = new State(configuration(), batchsize);
        state.latestToken = tokenizer().getSpecialTokens().get("<|begin_of_text|>");
        return state;
    }

    public static final class Configuration {
        public final Arch arch;                 // architecture
        public final int dim;                   // transformer dimension
        public final int hiddenDim;             // dimension of the ffn layers
        public final int numberOfLayers;        // number of layers
        public final int numberOfHeads;         // number of query heads
        public final int numberOfKeyValueHeads; // number of key/value heads (can be < query heads because of multiquery)
        public final int vocabularySize;        // vocabulary size (number of tokens)
        public final int contextLength;         // max sequence length
        public final boolean sharedWeights;
        public final float rmsNormEps;
        public final float ropeTheta;
        public final int headSize;

        public Configuration(Arch arch, int dim, int hiddenDim, int numberOfLayers, int numberOfHeads, int numberOfKeyValueHeads,
                int vocabularySize, int contextLength, boolean sharedWeights, float rmsNormEps, float ropeTheta) {
            this.arch = arch;
            this.dim = dim;
            this.hiddenDim = hiddenDim;
            this.numberOfLayers = numberOfLayers;
            this.numberOfHeads = numberOfHeads;
            this.numberOfKeyValueHeads = numberOfKeyValueHeads;
            this.vocabularySize = vocabularySize;
            this.contextLength = contextLength;
            this.sharedWeights = sharedWeights;
            this.rmsNormEps = rmsNormEps;
            this.ropeTheta = ropeTheta;
            this.headSize = dim / numberOfHeads;
        }

        public Configuration withContextLength(int newContextLength) {
            if (newContextLength < 0) {
                return this; // no change
            }
            return new Configuration(this.arch, this.dim, this.hiddenDim, this.numberOfLayers, this.numberOfHeads,
                    this.numberOfKeyValueHeads, this.vocabularySize, newContextLength,
                    this.sharedWeights, this.rmsNormEps, this.ropeTheta);
        }
    }
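
    // Worked example (added; the numbers are illustrative assumptions, not read
    // from any model file): a Llama-3-8B-style configuration with dim = 4096,
    // numberOfHeads = 32 and numberOfKeyValueHeads = 8 yields
    //   headSize = 4096 / 32     = 128
    //   kvDim    = 4096 * 8 / 32 = 1024  (see State and forward())
    //   kvMul    = 32 / 8        = 4     (each kv head is shared by 4 query heads)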

    public static final class Weights {
        // token embedding table
        public final FloatTensor token_embedding_table; // (vocab_size, dim)
        // weights for rmsnorms
        public final FloatBuffer[] rms_att_weight; // (layer, dim) rmsnorm weights
        // weights for matmuls
        public final FloatTensor[] wq; // (layer, n_heads * head_size, dim)
        public final FloatTensor[] wk; // (layer, n_kv_heads * head_size, dim)
        public final FloatTensor[] wv; // (layer, n_kv_heads * head_size, dim)
        public final FloatTensor[] wo; // (layer, dim, n_heads * head_size)
        public final FloatTensor[] q_bias; // (layer, dim)
        public final FloatTensor[] k_bias; // (layer, kv_dim)
        public final FloatTensor[] v_bias; // (layer, kv_dim)
        public final FloatBuffer[] rms_ffn_weight; // (layer, dim)
        // weights for ffn
        public final FloatTensor[] w1; // (layer, hidden_dim, dim)
        public final FloatTensor[] w2; // (layer, dim, hidden_dim)
        public final FloatTensor[] w3; // (layer, hidden_dim, dim)
        // final rmsnorm weights
        public final FloatBuffer rms_final_weight; // (dim,)
        // freq_cis for RoPE relative positional embeddings
        public final FloatBuffer freq_cis_real; // (seq_len, head_size/2)
        public final FloatBuffer freq_cis_imag; // (seq_len, head_size/2)
        // (optional) classifier weights for the logits, on the last layer
        public final FloatTensor wcls; // (vocab_size, dim)

        public Weights(FloatTensor token_embedding_table, FloatBuffer[] rms_att_weight, FloatTensor[] wq,
                FloatTensor[] wk, FloatTensor[] wv,
                FloatTensor[] q_bias, FloatTensor[] k_bias, FloatTensor[] v_bias,
                FloatTensor[] wo, FloatBuffer[] rms_ffn_weight,
                FloatTensor[] w1, FloatTensor[] w2, FloatTensor[] w3, FloatBuffer rms_final_weight,
                FloatBuffer freq_cis_real, FloatBuffer freq_cis_imag, FloatTensor wcls) {
            this.token_embedding_table = token_embedding_table;
            this.rms_att_weight = rms_att_weight;
            this.wq = wq;
            this.wk = wk;
            this.wv = wv;
            this.q_bias = q_bias;
            this.k_bias = k_bias;
            this.v_bias = v_bias;
            this.wo = wo;
            this.rms_ffn_weight = rms_ffn_weight;
            this.w1 = w1;
            this.w2 = w2;
            this.w3 = w3;
            this.rms_final_weight = rms_final_weight;
            this.freq_cis_real = freq_cis_real;
            this.freq_cis_imag = freq_cis_imag;
            this.wcls = wcls;
        }
    }

    public static final class State {

        // current wave of activations
        public final int batchsize;
        public final FloatTensor[] x; // activation at current time stamp (dim,)
        public final FloatTensor[] xb; // same, but inside a residual branch (dim,)
        public final FloatTensor[] xb2; // an additional buffer just for convenience (dim,)
        public final FloatTensor[] hb; // buffer for hidden dimension in the ffn (hidden_dim,)
        public final FloatTensor[] hb2; // buffer for hidden dimension in the ffn (hidden_dim,)
        public final FloatTensor[] q; // query (dim,)
        public final FloatTensor[] k; // key (dim,)
        public final FloatTensor[] v; // value (dim,)
        public final FloatTensor[] att; // buffer for scores/attention values (n_heads, seq_len)
        public final FloatTensor logits; // output logits

        // kv cache
        public final FloatTensor[] keyCache; // (n_layer, seq_len, kv_dim)
        public final FloatTensor[] valueCache; // (n_layer, seq_len, kv_dim)

        /** last index in previous block */
        int idxPrevBlock;

        public int latestToken;

        State(Configuration config, int batchsize) {
            this.batchsize = batchsize;
            this.x = allocate(batchsize, config.dim);
            this.xb = allocate(batchsize, config.dim);
            this.xb2 = allocate(batchsize, config.dim);
            this.hb = allocate(batchsize, config.hiddenDim);
            this.hb2 = allocate(batchsize, config.hiddenDim);
            this.q = allocate(batchsize, config.dim);
            this.k = allocate(batchsize, config.dim);
            this.v = allocate(batchsize, config.dim);
            this.att = allocate(batchsize, config.numberOfHeads, config.contextLength);
            idxPrevBlock = -1;

            this.logits = ArrayFloatTensor.allocate(config.vocabularySize);
            int kvDim = (config.dim * config.numberOfKeyValueHeads) / config.numberOfHeads;
            this.keyCache = Stream.generate(() -> ArrayFloatTensor.allocate(config.contextLength, kvDim)).limit(config.numberOfLayers).toArray(FloatTensor[]::new);
            this.valueCache = Stream.generate(() -> ArrayFloatTensor.allocate(config.contextLength, kvDim)).limit(config.numberOfLayers).toArray(FloatTensor[]::new);
        }

        private static FloatTensor[] allocate(int numTokens, int... dims) {
            return IntStream.range(0, numTokens)
                    .mapToObj(i -> ArrayFloatTensor.allocate(dims))
                    .toArray(FloatTensor[]::new);
        }
    }
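
    // Hedged sketch (added, not in the original sources): the kv caches above
    // dominate State's memory footprint. This helper estimates their total size
    // in floats, i.e. 2 (keys + values) * layers * context length * kvDim.
    static long kvCacheFloats(Configuration config) {
        long kvDim = ((long) config.dim * config.numberOfKeyValueHeads) / config.numberOfHeads;
        return 2L * config.numberOfLayers * config.contextLength * kvDim;
    }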

    static void rmsnorm(FloatTensor out, FloatTensor x, FloatBuffer weight, int size, float rmsNormEps) {
        // calculate sum of squares
        float ss = x.reduce(0, size, 0f, (acc, xi) -> acc + xi * xi);
        ss /= size;
        ss += rmsNormEps;
        ss = (float) (1.0 / Math.sqrt(ss));
        // normalize and scale
        for (int i = 0; i < size; ++i) {
            out.setFloat(i, weight.get(i) * (ss * x.getFloat(i)));
        }
    }
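
    // Reference sketch (added): the same RMS normalization on plain float[]
    // arrays, to make the math above explicit:
    //   out[i] = weight[i] * x[i] / sqrt(mean(x^2) + eps)
    static void rmsnormReference(float[] out, float[] x, float[] weight, float rmsNormEps) {
        float ss = 0f;
        for (float xi : x) ss += xi * xi;
        float scale = (float) (1.0 / Math.sqrt(ss / x.length + rmsNormEps));
        for (int i = 0; i < x.length; i++) {
            out[i] = weight[i] * (scale * x[i]);
        }
    }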

    static FloatTensor forward(Llama model, State state, int[] tokens, int position, boolean computeLogits) {
        // a few convenience variables
        Configuration config = model.configuration();
        Weights weights = model.weights();
        int dim = config.dim;
        int headSize = config.headSize;
        int kvDim = (config.dim * config.numberOfKeyValueHeads) / config.numberOfHeads;
        int kvMul = config.numberOfHeads / config.numberOfKeyValueHeads; // integer multiplier of the kv sharing in multiquery
        float sqrtHeadSize = (float) Math.sqrt(headSize);
        final int nTokens = tokens.length;

        // copy the token embedding into x
        FloatTensor.parallelFor(0, nTokens, t ->
            weights.token_embedding_table.copyTo(tokens[t] * dim, state.x[t], 0, dim)
        );

        // forward all the layers
        for (int l = 0; l < config.numberOfLayers; l++) {

            // attention rmsnorm
            final int curLayer = l;
            FloatTensor.parallelFor(0, nTokens, t ->
                rmsnorm(state.xb[t], state.x[t], weights.rms_att_weight[curLayer], dim, config.rmsNormEps)
            );

            // qkv matmuls for this position
            weights.wq[l].matmul(nTokens, state.xb, state.q, dim, dim);
            weights.wk[l].matmul(nTokens, state.xb, state.k, kvDim, dim);
            weights.wv[l].matmul(nTokens, state.xb, state.v, kvDim, dim);

            // RoPE relative positional encoding: complex-valued rotate q and k in each head
            FloatTensor.parallelFor(0, nTokens, t -> {
                for (int i = 0; i < dim; i += 2) {
                    int head_dim = i % headSize;
                    float fcr = weights.freq_cis_real.get((position + t) * (headSize / 2) + (head_dim / 2));
                    float fci = weights.freq_cis_imag.get((position + t) * (headSize / 2) + (head_dim / 2));
                    int rotn = i < kvDim ? 2 : 1; // how many vectors? 2 = q & k, 1 = q only
                    for (int vi = 0; vi < rotn; vi++) {
                        FloatTensor vec = vi == 0 ? state.q[t] : state.k[t]; // the vector to rotate (query or key)
                        float v0 = vec.getFloat(i);
                        float v1 = vec.getFloat(i + 1);
                        vec.setFloat(i, v0 * fcr - v1 * fci);
                        vec.setFloat(i + 1, v0 * fci + v1 * fcr);
                    }
                }
            });
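
            // Note (added): fcr/fci are the cosine and sine of the rotation angle
            // for this position and head dimension. Under the usual RoPE
            // precomputation (an assumption; the tables are filled elsewhere),
            // freq = ropeTheta^(-head_dim / headSize), fcr = cos((position + t) * freq)
            // and fci = sin((position + t) * freq), so the loop applies the 2-D rotation
            // (v0, v1) -> (v0*cos - v1*sin, v0*sin + v1*cos) to each (even, odd) pair.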

            // save key,value at this time step (position) to our kv cache
            FloatTensor.parallelFor(0, nTokens, t -> {
                state.k[t].copyTo(0, state.keyCache[curLayer], (position + t) * kvDim, kvDim);
                state.v[t].copyTo(0, state.valueCache[curLayer], (position + t) * kvDim, kvDim);
            });

            // If the logits are not required, the attention and FFN of the last layer can be skipped entirely.
            if (!computeLogits && curLayer == config.numberOfLayers - 1) {
                state.idxPrevBlock = nTokens - 1;
                return null;
            }

            // multihead attention: iterate over all heads
            FloatTensor.parallelForLong(0, (long) nTokens * (long) config.numberOfHeads, ht -> {
                int token = (int) (ht / config.numberOfHeads);
                int h = (int) (ht % config.numberOfHeads);
                int qOffset = h * headSize;
                int attOffset = h * config.contextLength;

                for (int t = 0; t <= position + token; t++) {
                    int keyCacheOffset = t * kvDim + (h / kvMul) * headSize;
                    float score = state.q[token].dot(qOffset, state.keyCache[curLayer], keyCacheOffset, headSize);
                    score /= sqrtHeadSize;
                    state.att[token].setFloat(attOffset + t, score);
                }

                state.att[token].softmaxInPlace(attOffset, position + token + 1);

                int xbOffset = h * headSize;
                state.xb[token].fillInPlace(xbOffset, headSize, 0f);

                for (int t = 0; t <= position + token; t++) {
                    int vOffset = t * kvDim + (h / kvMul) * headSize;
                    float a = state.att[token].getFloat(attOffset + t);
                    state.xb[token].saxpyInPlace(xbOffset, state.valueCache[curLayer], vOffset, headSize, a);
                }
            });
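
            // Note (added): per head this is standard scaled dot-product attention
            // over the kv cache, att = softmax(q . K / sqrt(headSize)) followed by
            // xb = att . V; h / kvMul maps each query head onto its shared
            // key/value head (grouped-query attention).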

            // final matmul to get the output of the attention
            weights.wo[l].matmul(nTokens, state.xb, state.xb2, dim, dim);

            // residual connection back into x
            FloatTensor.parallelFor(0, nTokens, t -> {
                state.x[t].addInPlace(state.xb2[t]);
            });

            // ffn rmsnorm
            FloatTensor.parallelFor(0, nTokens, t -> {
                rmsnorm(state.xb[t], state.x[t], weights.rms_ffn_weight[curLayer], dim, config.rmsNormEps);
            });

            // Now for FFN in PyTorch we have: self.w2(F.silu(self.w1(x)) * self.w3(x))
            weights.w1[l].matmul(nTokens, state.xb, state.hb, config.hiddenDim, dim);
            weights.w3[l].matmul(nTokens, state.xb, state.hb2, config.hiddenDim, dim);

            // SwiGLU non-linearity
            FloatTensor.parallelFor(0, nTokens, t -> {
                state.hb[t].mapInPlace(value -> value / (float) (1.0 + Math.exp(-value)));
            });

            // elementwise multiply with w3(x)
            FloatTensor.parallelFor(0, nTokens, t -> {
                state.hb[t].multiplyInPlace(state.hb2[t]);
            });
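
            // Note (added): mapInPlace above computes silu(x) = x * sigmoid(x)
            // = x / (1 + e^(-x)), and the elementwise multiply applies the gate,
            // giving hb = silu(w1(x)) * w3(x) as in the PyTorch formula cited above.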

            // final matmul to get the output of the ffn
            weights.w2[l].matmul(nTokens, state.hb, state.xb, dim, config.hiddenDim);

            // residual connection
            FloatTensor.parallelFor(0, nTokens, t -> {
                state.x[t].addInPlace(state.xb[t]);
            });
        }

        // final rmsnorm
        FloatTensor.parallelFor(0, nTokens, t -> {
            rmsnorm(state.x[t], state.x[t], weights.rms_final_weight, dim, config.rmsNormEps);
        });

        // classifier into logits
        weights.wcls.matmul(state.x[nTokens - 1], state.logits, config.vocabularySize, dim);
        state.idxPrevBlock = nTokens - 1;

        return state.logits;
    }
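
    // Summary (added): each layer in forward() follows the pre-norm transformer
    // pattern x = x + Attention(rmsnorm(x)); x = x + FFN(rmsnorm(x)), and only
    // the last token's activation is projected through wcls into vocabulary logits.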

    /**
     * LLM generation entry point: ingests the prompt tokens, then generates new tokens.
     *
     * <p>
     * All prompt tokens are ingested first, then inference starts until a stop token is found.
     * The returned tokens only include generated/inferred tokens.
     *
     * @param model model to run inference on (including weights, configuration, tokenizer ...)
     * @param state state of the model, e.g. key/value caches ... this is mutated by this call
     * @param startPosition start prompt ingestion + inference at this position in the context, e.g. useful if state was kept across calls (chained generation); 0 implies no previous context
     * @param promptTokens prompt tokens to ingest; all prompt tokens are ingested, given there is enough capacity left in the context
     * @param stopTokens set of tokens that abort generation during inference; stop tokens do not affect prompt ingestion
     * @param maxTokens maximum number of tokens; capped to the {@link Configuration#contextLength context length} if this value is negative or greater than the context length
     * @param sampler {@link Sampler strategy} used to select tokens
     * @param onTokenGenerated callback; if non-null, it is called every time a token is inferred, i.e. it is not called while ingesting prompt tokens
     * @return list of generated/inferred tokens, including the stop token, if any; does not include any token from the prompt
     */
    public static List<Integer> generateTokens(Llama model, State state, int startPosition, List<Integer> promptTokens,
            Set<Integer> stopTokens, int maxTokens, Sampler sampler,
            IntConsumer onTokenGenerated) {
        //long startNanos = System.nanoTime();
        //long startGen = 0;
        if (maxTokens < 0 || model.configuration().contextLength < maxTokens) {
            maxTokens = model.configuration().contextLength;
        }
        List<Integer> generatedTokens = new ArrayList<>(maxTokens);
        int token = state.latestToken; // BOS token on the first call, see createNewState()
        int nextToken;
        int promptIndex = 0;
        for (int position = startPosition; position < maxTokens; ++position) {
            if (promptIndex < promptTokens.size()) {
                final int nTokens = Math.min(maxTokens - position, Math.min(promptTokens.size() - promptIndex, state.batchsize));
                final int[] tokens = new int[nTokens];
                for (int i = 0; i < nTokens; i++) {
                    tokens[i] = promptTokens.get(promptIndex + i);
                    //System.out.print(Tokenizer.replaceControlCharacters(model.tokenizer().decode(List.of(tokens[i]))));
                }
                //System.out.format("position=%d, promptIdx=%d, promptSize=%d, tokens=%s%n", position, promptIndex, promptTokens.size(), Arrays.toString(tokens));

                boolean computeLogits = promptIndex + nTokens >= promptTokens.size();
                forward(model, state, tokens, position, computeLogits);
                position += nTokens - 1;
                promptIndex += nTokens;
                if (promptIndex < promptTokens.size()) {
                    continue;
                }
                //startGen = System.nanoTime();
            } else {
                forward(model, state, new int[]{token}, position, true);
            }
            nextToken = sampler.sampleToken(state.logits);
            //System.out.print(Tokenizer.replaceControlCharacters(model.tokenizer().decode(List.of(nextToken))));
            generatedTokens.add(nextToken);
            if (onTokenGenerated != null) {
                onTokenGenerated.accept(nextToken);
            }
            if (stopTokens.contains(nextToken)) {
                break;
            }
            state.latestToken = token = nextToken;
        }

        //long elapsedNanos = System.nanoTime() - startNanos;
        //long promptNanos = startGen - startNanos;
        //long genNanos = elapsedNanos - startGen + startNanos;
        //System.out.printf(startPosition + promptIndex + generatedTokens.size(), model.configuration().contextLength, promptTokens.size() / (promptNanos / 1_000_000_000.0), promptTokens.size(), generatedTokens.size() / (genNanos / 1_000_000_000.0), generatedTokens.size());

        return generatedTokens;
    }
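
    /*
     * Usage sketch (added; illustrative only. Model loading, the tokenizer's
     * encode method and the Sampler construction below are assumptions, not
     * APIs confirmed by this file):
     *
     *   Llama model = ...;                        // loaded elsewhere, e.g. from a model file
     *   State state = model.createNewState(16);   // ingest the prompt in batches of 16
     *   List<Integer> prompt = model.tokenizer().encode("What is YaCy?"); // hypothetical encode method
     *   Set<Integer> stop = Collections.singleton(
     *           model.tokenizer().getSpecialTokens().get("<|eot_id|>"));  // assumed stop token
     *   List<Integer> out = generateTokens(model, state, 0, prompt, stop, 256, sampler, null);
     */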
}