package com.hdl.photovoltaic.internet;
|
|
import android.text.TextUtils;
|
import android.util.Log;
|
|
import com.google.gson.Gson;
|
import com.google.gson.JsonObject;
|
import com.google.gson.JsonParser;
|
import com.hdl.photovoltaic.config.UserConfigManage;
|
|
import org.jetbrains.annotations.NotNull;
|
|
import java.io.BufferedReader;
|
import java.io.IOException;
|
import java.util.Objects;
|
import java.util.concurrent.TimeUnit;
|
|
import okhttp3.Cache;
|
import okhttp3.Call;
|
import okhttp3.Callback;
|
import okhttp3.MediaType;
|
import okhttp3.OkHttpClient;
|
import okhttp3.Request;
|
import okhttp3.RequestBody;
|
import okhttp3.Response;
|
import okhttp3.ResponseBody;
|
import okio.Buffer;
|
|
/**
|
* AI 聊天流式请求工具类
|
* 支持 SSE (Server-Sent Events) 流式响应
|
* 类似 ChatGPT 的流式输出效果
|
*/
|
public class ChatStreamClient {
|
// 单例实例
|
private static volatile ChatStreamClient instance;
|
|
/**
|
* 获取单例实例
|
*/
|
public static ChatStreamClient getInstance() {
|
if (instance == null) {
|
synchronized (ChatStreamClient.class) {
|
if (instance == null) {
|
instance = new ChatStreamClient();
|
}
|
}
|
}
|
return instance;
|
}
|
|
// ==================== 常量定义 ====================
|
private static final MediaType JSON = MediaType.get("application/json");
|
private static final String SSE_MEDIA_TYPE = "text/event-stream";
|
private static final String JSON_MEDIA_TYPE = "application/json";
|
private static final String DONE_FLAG = "[DONE]";
|
|
private final OkHttpClient okHttpClient = new OkHttpClient.Builder()
|
.connectTimeout(30, TimeUnit.SECONDS)
|
.readTimeout(0, TimeUnit.SECONDS)
|
.writeTimeout(30, TimeUnit.SECONDS)
|
// .connectionPool(new ConnectionPool(
|
// builder.maxIdleConnections,
|
// builder.keepAliveDuration,
|
// builder.timeUnit
|
// ))
|
// .retryOnConnectionFailure(builder.retryOnFailure)
|
// .addInterceptor(new HttpLoggingInterceptor()) // 可选:添加日志
|
.build();
|
; // HTTP 客户端
|
private final Gson gson = new Gson(); // JSON 解析器
|
private final String apiKey = "Bearer " + UserConfigManage.getInstance().getAgentSecret(); // API 密钥
|
private final String baseUrl = UserConfigManage.getInstance().getAgentUrl();// "https://agent.hdlcontrol.com/v1";; // 基础 URL
|
|
|
// ==================== 回调接口 ====================
|
public interface ChatCallback {
|
/**
|
* 收到消息片段时回调(流式输出)
|
*
|
* @param content 消息内容片段
|
*/
|
void onMessage(String content);
|
|
/**
|
* 消息完成时回调
|
*/
|
default void onComplete() {
|
}
|
|
/**
|
* 发生错误时回调
|
*
|
* @param error 错误信息
|
*/
|
default void onError(String error) {
|
}
|
|
/**
|
* 收到完整消息时回调(非流式模式使用)
|
*
|
* @param fullMessage 完整消息
|
*/
|
default void onFullMessage(String fullMessage) {
|
}
|
}
|
|
// ==================== 请求参数类 ====================
|
public static class ChatMode {
|
public boolean stream = true;
|
public boolean isGet = false;
|
public String url = "";
|
public Object data = null;
|
}
|
|
|
/**
|
* 发送流式聊天请求(完整参数)
|
*
|
* @param chatMode 请求参数
|
* @param callback 回调接口
|
* @return Cancelable 可取消的对象
|
*/
|
public Cancelable streamChat(ChatMode chatMode, ChatCallback callback) {
|
// // 确保是流式请求
|
// chatMode.stream = true;
|
// 构建 HTTP 请求
|
Request httpRequest = buildHttpRequest(chatMode);
|
// try {
|
// // 获取请求体
|
// if (httpRequest.body() != null) {
|
// Buffer buffer = new Buffer();
|
// httpRequest.body().writeTo(buffer);
|
// String body = buffer.readUtf8();
|
// // 注意:读取后记得关闭 buffer
|
// buffer.close();
|
// System.out.println("Request Body: " + httpRequest.url() + "\r\n" + body);
|
// } else {
|
// System.out.println("Request Body: " + httpRequest.url());
|
// }
|
// } catch (Exception e) {
|
// }
|
// 创建可取消的 Call
|
Call call = okHttpClient.newCall(httpRequest);
|
|
// 执行异步请求
|
call.enqueue(new StreamCallbackHandler(call, chatMode, callback));
|
|
// 返回可取消对象
|
return () -> {
|
if (!call.isCanceled()) {
|
call.cancel();
|
}
|
};
|
}
|
|
/**
|
* 发送非流式聊天请求(一次性返回)
|
*
|
* @param request 请求参数
|
* @return 完整响应
|
*/
|
public String chatSync(ChatMode request) {
|
// request.stream = false;
|
Request httpRequest = buildHttpRequest(request);
|
try (Response response = okHttpClient.newCall(httpRequest).execute()) {
|
if (!response.isSuccessful()) {
|
return response.message() + "(" + response.code() + ")";
|
}
|
return Objects.requireNonNull(response.body()).string();
|
} catch (Exception e) {
|
return e.getMessage();
|
}
|
}
|
|
|
/**
|
* 构建 HTTP 请求
|
*/
|
private Request buildHttpRequest(ChatMode ChatMode) {
|
|
String jsonBody = "";
|
if (ChatMode.data != null) {
|
jsonBody = gson.toJson(ChatMode.data);
|
}
|
String newUrl = baseUrl + ChatMode.url;
|
if (ChatMode.isGet) {
|
return new Request.Builder()
|
.url(newUrl)
|
.get()
|
.addHeader("Authorization", apiKey)
|
.addHeader("Cache-Control", "no-cache")
|
.addHeader("Connection", "keep-alive")
|
.build();
|
} else {
|
return new Request.Builder()
|
.url(newUrl)
|
.post(RequestBody.create(jsonBody, JSON))
|
.addHeader("Authorization", apiKey)
|
.addHeader("Accept", ChatMode.stream ? SSE_MEDIA_TYPE : JSON_MEDIA_TYPE)
|
.addHeader("Cache-Control", "no-cache")
|
.addHeader("Connection", "keep-alive")
|
.build();
|
}
|
}
|
|
|
/**
|
* 解析流式数据块
|
*/
|
private String parseStreamChunk(String data) {
|
if (data == null || data.isEmpty() || data.equals(DONE_FLAG)) {
|
return "";
|
}
|
try {
|
JsonObject json = JsonParser.parseString(data).getAsJsonObject();
|
String event = json.has("event") ? json.get("event").getAsString() : "";
|
if (event.equals("message")) {
|
return json.getAsString();
|
} else if (event.equals("message_end")) {
|
return DONE_FLAG;
|
} else if (event.equals("error")) {
|
return "error";
|
} else {
|
return "";
|
}
|
|
} catch (Exception e) {
|
// 解析失败,返回原始数据
|
return data;
|
}
|
}
|
|
// ==================== 流式响应处理器 ====================
|
private class StreamCallbackHandler implements Callback {
|
private final Call call;
|
private final ChatMode request;
|
private final ChatCallback callback;
|
private final StringBuilder fullContent = new StringBuilder();
|
|
public StreamCallbackHandler(Call call, ChatMode request, ChatCallback callback) {
|
this.call = call;
|
this.request = request;
|
this.callback = callback;
|
}
|
|
@Override
|
public void onFailure(@NotNull Call call, @NotNull IOException e) {
|
callback.onError("Network error: " + e.getMessage());
|
|
}
|
|
@Override
|
public void onResponse(@NotNull Call call, @NotNull Response response) {
|
if (!response.isSuccessful()) {
|
callback.onError("HTTP error: " + response.code());
|
response.close();
|
return;
|
}
|
// 检查内容类型
|
MediaType contentType = response.body().contentType();
|
if (contentType == null || !contentType.toString().startsWith(SSE_MEDIA_TYPE)) {
|
// 如果不是流式,可能是普通 JSON
|
try {
|
String body = response.body().string();
|
callback.onFullMessage(body);
|
Log.d("普通回复===", body);
|
} catch (IOException e) {
|
callback.onError("Parse error: " + e.getMessage());
|
}
|
response.close();
|
return;
|
}
|
|
// 流式处理
|
try (ResponseBody responseBody = response.body()) {
|
BufferedReader reader = new BufferedReader(responseBody.charStream());
|
String line;
|
while ((line = reader.readLine()) != null) {
|
if (call.isCanceled()) {
|
break;
|
}
|
|
if (line.startsWith("data:")) {
|
String data = line.substring(5).trim();
|
Log.d("流式处理===", line);
|
if (data.equals(DONE_FLAG)) {
|
callback.onComplete();
|
break;
|
}
|
String content = parseStreamChunk(data);
|
if (!TextUtils.isEmpty(content)) {
|
if (content.equals(DONE_FLAG)) {
|
callback.onComplete();
|
break;
|
} else if (content.equals("error")) {
|
callback.onError(data);
|
break;
|
}
|
fullContent.append(content);
|
callback.onMessage(content);
|
}
|
}
|
}
|
|
// // 如果没收到 DONE 但流结束了,也回调 complete
|
// if (isActive.get()) {
|
// callback.onComplete();
|
// }
|
|
} catch (IOException e) {
|
callback.onError("Stream error: " + e.getMessage());
|
}
|
}
|
}
|
|
// ==================== 可取消接口 ====================
|
public interface Cancelable {
|
void cancel();
|
}
|
|
|
/**
|
* 释放资源(应用退出时调用)
|
*/
|
public void shutdown() {
|
okHttpClient.dispatcher().executorService().shutdown();
|
okHttpClient.connectionPool().evictAll();
|
try {
|
Cache cache = okHttpClient.cache();
|
if (cache != null) {
|
cache.close();
|
}
|
} catch (IOException e) {
|
e.printStackTrace();
|
}
|
}
|
}
|