mirror of
https://github.com/yacy/yacy_search_server.git
synced 2025-02-02 06:38:42 -05:00
246 lines
11 KiB
Java
246 lines
11 KiB
Java
/**
|
|
* RAGProxyServlet
|
|
* Copyright 2024 by Michael Peter Christen
|
|
* First released 17.05.2024 at https://yacy.net
|
|
*
|
|
* This library is free software; you can redistribute it and/or
|
|
* modify it under the terms of the GNU Lesser General Public
|
|
* License as published by the Free Software Foundation; either
|
|
* version 2.1 of the License, or (at your option) any later version.
|
|
*
|
|
* This library is distributed in the hope that it will be useful,
|
|
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
|
* Lesser General Public License for more details.
|
|
*
|
|
* You should have received a copy of the GNU Lesser General Public License
|
|
* along with this program in the file lgpl21.txt
|
|
* If not, see <http://www.gnu.org/licenses/>.
|
|
*/
|
|
|
|
package net.yacy.http.servlets;
|
|
|
|
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.OutputStream;
import java.net.HttpURLConnection;
import java.net.URI;
import java.net.URISyntaxException;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.util.Iterator;
import java.util.LinkedHashMap;

import javax.servlet.ServletException;
import javax.servlet.ServletOutputStream;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import org.apache.solr.client.solrj.SolrQuery;
import org.apache.solr.common.SolrDocument;
import org.apache.solr.common.SolrDocumentList;
import org.apache.solr.common.SolrException;
import org.apache.solr.servlet.cache.Method;
import org.json.JSONArray;
import org.json.JSONException;
import org.json.JSONObject;

import net.yacy.ai.OpenAIClient;
import net.yacy.cora.federate.solr.connector.EmbeddedSolrConnector;
import net.yacy.search.Switchboard;
import net.yacy.search.schema.CollectionSchema;
|
|
|
|
/**
|
|
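 * This class implements a Retrieval Augmented Generation ("RAG") proxy which uses a YaCy search index
 * to enrich a chat with search results. The proxy accepts OpenAI-compatible chat completion requests,
 * searches the local index for the latest user prompt, appends the found texts to that prompt and
 * streams the answer of the configured LLM back-end to the client.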
|
|
*/
|
|
public class RAGProxyServlet extends HttpServlet {
|
|
|
|
private static final long serialVersionUID = 3411544789759603107L;
|
|
|
|
//private static Boolean LLM_ENABLED = false;
|
|
//private static Boolean LLM_CONTROL_OLLAMA = true;
|
|
//private static Boolean LLM_ATTACH_QUERY = false; // instructs the proxy to attach the prompt generated to do the RAG search
|
|
//private static Boolean LLM_ATTACH_REFERENCES = false; // instructs the proxy to attach a list of sources that had been used in RAG
|
|
//private static String LLM_LANGUAGE = "en"; // used to select proper language in RAG augmentation
|
|
private static String LLM_SYSTEM_PREFIX = "\n\nYou may receive additional expert knowledge in the user prompt after a 'Additional Information' headline to enhance your knowledge. Use it only if applicable.";
|
|
private static String LLM_USER_PREFIX = "\n\nAdditional Information:\n\nbelow you find a collection of texts that might be useful to generate a response. Do not discuss these documents, just use them to answer the question above.\n\n";
|
|
private static String LLM_API_HOST = "http://localhost:11434"; // Ollama port; install ollama from https://ollama.com/
|
|
private static String LLM_QUERY_MODEL = "phi3:3.8b";
|
|
private static String LLM_ANSWER_MODEL = "llama3:8b"; // or "phi3:3.8b" i.e. on a Raspberry Pi 5
|
|
private static Boolean LLM_API_MODEL_OVERWRITING = true; // if true, the value configured in YaCy overwrites the client model
|
|
private static String LLM_API_KEY = ""; // not required; option to use this class to use a OpenAI API
|
|
|
|
@Override
|
|
public void service(ServletRequest request, ServletResponse response) throws IOException, ServletException {
|
|
response.setContentType("application/json;charset=utf-8");
|
|
|
|
HttpServletResponse hresponse = (HttpServletResponse) response;
|
|
HttpServletRequest hrequest = (HttpServletRequest) request;
|
|
|
|
// Add CORS headers
|
|
hresponse.setHeader("Access-Control-Allow-Origin", "*");
|
|
hresponse.setHeader("Access-Control-Allow-Methods", "POST, GET, OPTIONS, DELETE");
|
|
hresponse.setHeader("Access-Control-Allow-Headers", "Content-Type, Authorization");
|
|
|
|
final Method reqMethod = Method.getMethod(hrequest.getMethod());
|
|
if (reqMethod == Method.OTHER) {
|
|
// required to handle CORS
|
|
hresponse.setStatus(HttpServletResponse.SC_OK);
|
|
return;
|
|
}
|
|
|
|
// We expect a POST request
|
|
if (reqMethod != Method.POST) {
|
|
hresponse.sendError(HttpServletResponse.SC_METHOD_NOT_ALLOWED);
|
|
return;
|
|
}
|
|
|
|
// get the output stream early to be able to generate messages to the user before the actual retrieval starts
|
|
ServletOutputStream out = response.getOutputStream();
|
|
|
|
// read the body of the request and parse it as JSON
|
|
BufferedReader reader = request.getReader();
|
|
StringBuilder bodyBuilder = new StringBuilder();
|
|
String line;
|
|
while ((line = reader.readLine()) != null) {
|
|
bodyBuilder.append(line);
|
|
}
|
|
String body = bodyBuilder.toString();
|
|
JSONObject bodyObject;
|
|
try {
|
|
// get system message and user prompt
|
|
bodyObject = new JSONObject(body);
|
|
JSONArray messages = bodyObject.optJSONArray("messages");
|
|
JSONObject systemObject = messages.getJSONObject(0);
|
|
String system = systemObject.optString("content", ""); // the system prompt
|
|
JSONObject userObject = messages.getJSONObject(messages.length() - 1);
|
|
String user = userObject.optString("content", ""); // this is the latest prompt
|
|
|
|
// modify system and user prompt here in bodyObject to enable RAG
|
|
String query = searchWordsForPrompt(LLM_QUERY_MODEL, user);
|
|
out.print(responseLine("Searching for '" + query + "'\n\n").toString() + "\n"); out.flush();
|
|
LinkedHashMap<String, String> searchResults = searchResults(query, 4);
|
|
out.print(responseLine("Using the following sources for RAG:\n\n").toString() + "\n"); out.flush();
|
|
for (String s: searchResults.keySet()) {out.print(responseLine("- `" + s + "`\n").toString() + "\n"); out.flush();}
|
|
out.print(responseLine("\n").toString()); out.flush();
|
|
system += LLM_SYSTEM_PREFIX;
|
|
user += LLM_USER_PREFIX;
|
|
for (String s: searchResults.values()) user += s + "\n\n";
|
|
systemObject.put("content", system);
|
|
userObject.put("content", user);
|
|
|
|
if (LLM_API_MODEL_OVERWRITING) bodyObject.put("model", LLM_ANSWER_MODEL);
|
|
|
|
// write back modified bodyMap to body
|
|
body = bodyObject.toString();
|
|
|
|
// Open request to back-end service
|
|
URL url = new URI(LLM_API_HOST + "/v1/chat/completions").toURL();
|
|
HttpURLConnection conn = (HttpURLConnection) url.openConnection();
|
|
conn.setRequestMethod("POST");
|
|
conn.setRequestProperty("Content-Type", "application/json");
|
|
if (!LLM_API_KEY.isEmpty()) {
|
|
conn.setRequestProperty("Authorization", "Bearer " + LLM_API_KEY);
|
|
}
|
|
conn.setDoOutput(true);
|
|
|
|
// write the body to back-end LLM
|
|
try (OutputStream os = conn.getOutputStream()) {
|
|
os.write(body.getBytes());
|
|
os.flush();
|
|
}
|
|
|
|
// write back response of the back-end service to the client; use status of backend-response
|
|
int status = conn.getResponseCode();
|
|
//String rmessage = conn.getResponseMessage();
|
|
hresponse.setStatus(status);
|
|
|
|
if (status == 200) {
|
|
// read the response of the back-end line-by-line and write it to the client line-by-line
|
|
BufferedReader in = new BufferedReader(new InputStreamReader(conn.getInputStream()));
|
|
String inputLine;
|
|
while ((inputLine = in.readLine()) != null) {
|
|
out.print(inputLine); // i.e. data: {"id":"chatcmpl-69","object":"chat.completion.chunk","created":1715908287,"model":"llama3:8b","system_fingerprint":"fp_ollama","choices":[{"index":0,"delta":{"role":"assistant","content":"ߘ"},"finish_reason":null}]}
|
|
out.flush();
|
|
}
|
|
in.close();
|
|
}
|
|
out.close(); // close this here to end transmission
|
|
} catch (JSONException | URISyntaxException e) {
|
|
throw new IOException(e.getMessage());
|
|
}
|
|
}
|
|
|
|
public static LinkedHashMap<String, String> searchResults(String query, int count) {
|
|
Switchboard sb = Switchboard.getSwitchboard();
|
|
EmbeddedSolrConnector connector = sb.index.fulltext().getDefaultEmbeddedConnector();
|
|
// construct query
|
|
final SolrQuery params = new SolrQuery();
|
|
params.setQuery(CollectionSchema.text_t.getSolrFieldName() + ":" + query);
|
|
params.setRows(count);
|
|
params.setStart(0);
|
|
params.setFacet(false);
|
|
params.clearSorts();
|
|
params.setFields(CollectionSchema.sku.getSolrFieldName(), CollectionSchema.text_t.getSolrFieldName());
|
|
params.setIncludeScore(false);
|
|
params.set("df", CollectionSchema.text_t.getSolrFieldName());
|
|
|
|
// query the server
|
|
try {
|
|
final SolrDocumentList sdl = connector.getDocumentListByParams(params);
|
|
LinkedHashMap<String, String> a = new LinkedHashMap<String, String>();
|
|
Iterator<SolrDocument> i = sdl.iterator();
|
|
while (i.hasNext()) {
|
|
SolrDocument doc = i.next();
|
|
String url = (String) doc.getFieldValue(CollectionSchema.sku.getSolrFieldName());
|
|
String text = (String) doc.getFieldValue(CollectionSchema.text_t.getSolrFieldName());
|
|
a.put(url, text);
|
|
}
|
|
return a;
|
|
} catch (SolrException | IOException e) {
|
|
return new LinkedHashMap<String, String>();
|
|
}
|
|
}
|
|
|
|
private String searchWordsForPrompt(String model, String prompt) {
|
|
StringBuilder query = new StringBuilder();
|
|
String question = "Make a list of a maximum of four search words for the following question; use a JSON Array: " + prompt;
|
|
try {
|
|
OpenAIClient oaic = new OpenAIClient(LLM_API_HOST);
|
|
String[] a = OpenAIClient.stringsFromChat(oaic.chat(model, question, 80));
|
|
for (String s: a) query.append(s).append(' ');
|
|
return query.toString().trim();
|
|
} catch (IOException e) {
|
|
e.printStackTrace();
|
|
return "";
|
|
}
|
|
}
|
|
|
|
private static JSONObject responseLine(String payload) {
|
|
JSONObject j = new JSONObject(true);
|
|
try {
|
|
j.put("id", "log");
|
|
j.put("object", "chat.completion.chunk");
|
|
j.put("created", System.currentTimeMillis() / 1000);
|
|
j.put("model", "log");
|
|
j.put("system_fingerprint", "YaCy");
|
|
JSONArray choices = new JSONArray();
|
|
JSONObject choice = new JSONObject(true); // {"index":0,"delta":{"role":"assistant","content":"ߘ"
|
|
choice.put("index", 0);
|
|
JSONObject delta = new JSONObject(true);
|
|
delta.put("role", "assistant");
|
|
delta.put("content", payload);
|
|
choice.put("delta", delta);
|
|
choices.put(choice);
|
|
j.put("choices", choices);
|
|
//j.put("finish_reason", null); // this is problematic with the JSON library
|
|
} catch (JSONException e) {}
|
|
return j;
|
|
}
|
|
|
|
}
|