// use-chat.ts
'use client';
import { faker } from '@faker-js/faker';
import { useChat as useBaseChat } from 'ai/react';
import { useSettings } from '@/components/editor/settings';
import { createOllama } from 'ollama-ai-provider';
import { streamText } from 'ai';
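// Chat hook for the editor. It tries the /api/ai/command route first and, if
// that route fails, falls back to streaming directly from a local Ollama
// instance in the browser.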
export const useChat = () => {
const { keys, model } = useSettings();
return useBaseChat({
id: 'editor',
api: '/api/ai/command',
body: {
model: model.value,
},
fetch: async (input, init) => {
try {
// First try the normal API endpoint
const res = await fetch(input, init);
if (res.ok) return res;
// If API endpoint fails, fallback to direct Ollama call
        // JSON.parse is synchronous; no await needed
        const { messages } = JSON.parse((init?.body as string) || '{}');
const ollama = createOllama({
baseURL: 'http://localhost:11434/api'
});
const result = await streamText({
model: ollama(model.value || 'phi3'),
messages,
maxTokens: 2048,
temperature: 0.7,
});
return result.toDataStreamResponse({
headers: {
'Content-Type': 'text/event-stream',
'Cache-Control': 'no-cache',
'Connection': 'keep-alive',
}
});
} catch (error) {
console.error('Chat error:', error);
throw error;
}
},
});
};
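// Example usage (a minimal sketch, not part of this module): the hook drops
// into a client component exactly like the vanilla `useChat` from 'ai/react'.
// `ChatPanel` is an illustrative name; in a .tsx file this would render as JSX.
//
// const ChatPanel = () => {
//   const { messages, input, handleInputChange, handleSubmit } = useChat();
//   return (
//     <form onSubmit={handleSubmit}>
//       {messages.map((m) => (
//         <p key={m.id}>{m.role}: {m.content}</p>
//       ))}
//       <input value={input} onChange={handleInputChange} />
//     </form>
//   );
// };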
// export const useChat = () => {
// const { keys, model } = useSettings();
// return useBaseChat({
// id: 'editor',
// api: '/api/ai/command',
// body: {
// // !!! DEMO ONLY: don't use API keys client-side
// // apiKey: keys.openai,
// model: model.value,
// },
// fetch: async (input, init) => {
// const res = await fetch(input, init);
// if (!res.ok) {
// // Mock the API response. Remove it when you implement the route /api/ai/command
// await new Promise((resolve) => setTimeout(resolve, 400));
// const stream = fakeStreamText();
// return new Response(stream, {
// headers: {
// Connection: 'keep-alive',
// 'Content-Type': 'text/plain',
// },
// });
// }
// return res;
// },
// });
// };
// Used for testing. Remove it after implementing the useChat API route.
const fakeStreamText = ({
chunkCount = 10,
streamProtocol = 'data',
}: {
chunkCount?: number;
streamProtocol?: 'data' | 'text';
} = {}) => {
const chunks = Array.from({ length: chunkCount }, () => ({
delay: faker.number.int({ max: 150, min: 50 }),
    texts: faker.lorem.words({ max: 3, min: 1 }) + ' ', // trailing space so chunks don't run together
}));
const encoder = new TextEncoder();
return new ReadableStream({
async start(controller) {
for (const chunk of chunks) {
await new Promise((resolve) => setTimeout(resolve, chunk.delay));
if (streamProtocol === 'text') {
controller.enqueue(encoder.encode(chunk.texts));
} else {
controller.enqueue(
encoder.encode(`0:${JSON.stringify(chunk.texts)}\n`)
);
}
}
      if (streamProtocol === 'data') {
        // Encode the final frame like the text frames above; mixing raw
        // strings and Uint8Arrays in one stream breaks downstream consumers.
        controller.enqueue(
          encoder.encode(
            `d:{"finishReason":"stop","usage":{"promptTokens":0,"completionTokens":${chunks.length}}}\n`
          )
        );
      }
controller.close();
},
});
};
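// For reference, with the default `streamProtocol: 'data'` the stream above
// emits AI SDK data-protocol frames (values illustrative):
//   0:"lorem ipsum "
//   0:"dolor sit "
//   d:{"finishReason":"stop","usage":{"promptTokens":0,"completionTokens":10}}
// `0:` frames carry text deltas; the final `d:` frame reports the finish
// reason and token usage that the useChat client expects to parse.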
const streamOllamaText = async ({
messages,
model = 'phi3',
streamProtocol = 'data',
}: {
messages: any[];
model?: string;
streamProtocol?: 'data' | 'text';
}) => {
  // `createOllama` builds an AI SDK provider and exposes no public baseURL
  // field, so the raw streaming request below uses the base URL directly.
  const baseURL = process.env.OLLAMA_BASE_URL || 'http://localhost:11434/api';
const encoder = new TextEncoder();
return new ReadableStream({
async start(controller) {
try {
      const response = await fetch(`${baseURL}/chat`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({
model: model,
messages: messages,
stream: true,
}),
});
      if (!response.ok) {
        throw new Error(`Ollama request failed: ${response.status}`);
      }
      const reader = response.body?.getReader();
      if (!reader) throw new Error('No reader available');
      const decoder = new TextDecoder();
      let buffer = '';
      while (true) {
        const { done, value } = await reader.read();
        if (done) break;
        // Ollama streams NDJSON; a network chunk can end mid-line, so keep any
        // trailing partial line in the buffer until the next read completes it.
        buffer += decoder.decode(value, { stream: true });
        const lines = buffer.split('\n');
        buffer = lines.pop() ?? '';
        for (const line of lines) {
          if (!line.trim()) continue;
          const json = JSON.parse(line);
          if (streamProtocol === 'text') {
            controller.enqueue(encoder.encode(json.message?.content || ''));
          } else {
            controller.enqueue(
              encoder.encode(`0:${JSON.stringify(json.message?.content || '')}\n`)
            );
          }
        }
      }
if (streamProtocol === 'data') {
controller.enqueue(
encoder.encode(`d:{"finishReason":"stop","usage":{"promptTokens":0,"completionTokens":0}}\n`)
);
}
controller.close();
} catch (error) {
controller.error(error);
}
},
});
};
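// For reference, Ollama's /api/chat endpoint streams NDJSON lines shaped like
// (fields abridged, values illustrative):
//   {"model":"phi3","message":{"role":"assistant","content":"Hi"},"done":false}
//   {"model":"phi3","message":{"role":"assistant","content":" there"},"done":false}
//   {"model":"phi3","done":true}
// which is why the loop above parses line by line and reads `message.content`.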
// Usage in an API route (this handler belongs in e.g. app/api/ai/command/route.ts,
// not in this 'use client' file).
export async function POST(req: Request) {
  const { messages } = await req.json();
  const stream = await streamOllamaText({
    messages,
    model: 'phi3', // or any other Ollama model
    streamProtocol: 'data',
  });
  // `StreamingTextResponse` is not imported above and was removed from newer
  // AI SDK releases; returning a plain Response over the stream works the same.
  return new Response(stream, {
    headers: {
      'Content-Type': 'text/plain; charset=utf-8',
      // Marks the body as AI SDK data-stream protocol for the useChat client.
      'x-vercel-ai-data-stream': 'v1',
    },
  });
}
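// Quick smoke test for the route above (a sketch, assuming the route is
// deployed at /api/ai/command): logs the raw data-protocol frames.
//
// const res = await fetch('/api/ai/command', {
//   method: 'POST',
//   headers: { 'Content-Type': 'application/json' },
//   body: JSON.stringify({ messages: [{ role: 'user', content: 'Hello' }] }),
// });
// const reader = res.body!.getReader();
// const decoder = new TextDecoder();
// for (;;) {
//   const { done, value } = await reader.read();
//   if (done) break;
//   console.log(decoder.decode(value, { stream: true }));
// }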