반응형
해당 포스팅은 AI허브 사이트에서 제공된 사용자 대화 데이터를 이용합니다.
해당 포스팅은 이전 포스팅과 연관되어 있습니다.
해당 포스팅은 streamlit 을 이용하여 웹 대시보드에서 AI봇과 대화를 나누는 것을 구현합니다.
라이브러리 및 데이터셋이 준비가 되어 있지 않다면 이전 포스팅을 확인해주세요.
AI Hub의 데이터셋을 이용하여 챗봇 만들기 - 1. 데이터 준비 및 라이브러리 설치
완성본 소스 파일 및 데이터셋
https://github.com/luvris2/streamlit_chatbot
- 파이썬 파일
- app.py : 메인 파일
- simple_chatbot.ipynb : 모델 로드와 데이터 가공, 전처리 과정이 담겨 있는 파일
- test.py : 간단한 테스트 파일, streamlit_chat 라이브러리 호출 확인용
- 추가 데이터 파일
- AI_chatbot.pkl : 모델링 정보를 저장, 캐쉬를 사용하지 않고 사용 할 때 사용하기 위함 (선택)
- wellness_dataset ~ .csv : AI Hub에서 제공하는 웰니스 대화 데이터셋, 원본 파일과 가공된 파일
# 소스 코드
라이브러리 호출
import streamlit as st
from streamlit_chat import message
import pandas as pd
from sentence_transformers import SentenceTransformer #sentenceBERT 모델 사용
from sklearn.metrics.pairwise import cosine_similarity
import json
캐쉬 정의
@st.cache(allow_output_mutation=True)
def cached_model():
model = SentenceTransformer('jhgan/ko-sroberta-multitask') # 해당 모델 사용
return model
@st.cache(allow_output_mutation=True)
def get_dataset():
df = pd.read_csv('data/wellness_dataset.csv') # AI Hub에서 제공하는 웰니스 데이터셋 사용
df['embedding'] = df['embedding'].apply(json.loads)
return df
model = cached_model()
df = get_dataset()
메인 소스 코드
st.header('심리상담 챗봇')
if 'generated' not in st.session_state:
st.session_state['generated'] = []
if 'past' not in st.session_state:
st.session_state['past'] = []
# 텍스트를 입력하여 봇과 대화 할 수 있는 폼 생성
# clear_on_submit 옵션을 통해서 submit 하면 폼의 내용이 지워짐
with st.form('form', clear_on_submit=True):
user_input = st.text_input('당신 : ', '')
submitted = st.form_submit_button('전송')
# 메시지를 입력 후 전송을 누를 경우
if submitted and user_input:
embedding = model.encode(user_input) # 유저가 입력한 문장을 벡터라이징
# 입력한 메시지의 유사도를 확인하여 가장 유사한 답변을 제시
df['similarity'] = df['embedding'].map(lambda x: cosine_similarity([embedding], [x]).squeeze())
answer = df.loc[ df['similarity'].idxmax() ] # 가장 유사한 답변을 저장
# 유저와 챗봇의 대화 내용을 저장
st.session_state.past.append(user_input)
st.session_state.generated.append(answer['챗봇'])
# 저장된 대화 내용 보여주기
for i in range(len(st.session_state['past'])):
message(st.session_state['past'][i], is_user=True, key=str(i) + '_user')
if len(st.session_state['generated']) > i:
message(st.session_state['generated'][i], key=str(i) + '_bot')
# 구현 화면
반응형