Test-0-1's picture
Sentence Transformer for stocksense
9f8f782 verified
import os
# Keep the Hugging Face cache under /tmp: it is ephemeral storage that can be
# written to without admin privileges.
cache_dir = "/tmp/hf_cache"
hub_dir = os.path.join(cache_dir, "hub")
# One recursive, idempotent call creates both cache_dir and its 'hub'
# subdirectory; exist_ok=True already covers the "already there" case, so no
# separate os.path.exists() checks are needed.
os.makedirs(hub_dir, exist_ok=True)
# Set the cache location for Hugging Face libraries. This happens before the
# sentence_transformers import below so the setting is in place when those
# libraries are loaded.
os.environ["HF_HOME"] = cache_dir
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List
from sentence_transformers import SentenceTransformer
from fastapi.middleware.cors import CORSMiddleware
# Create the FastAPI application object that the route decorators attach to.
app = FastAPI()

# CORS is left wide open (any origin, method and header) so that external
# applications can access the API.
# NOTE(review): the FastAPI CORS documentation advises against combining
# wildcard origins with allow_credentials=True — confirm credentials are
# actually required by the callers.
_cors_settings = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)
# Instantiate the Sentence Transformer model once at import time; the
# /encode_batch handler reuses this single module-level instance.
_MODEL_ID = "sentence-transformers/all-MiniLM-L6-v2"
model = SentenceTransformer(_MODEL_ID)
# Request body schema for the batch-encoding endpoint (POST /encode_batch).
class TextBatchInput(BaseModel):
    # The batch of texts to encode.
    texts: List[str]
    # Free-form string sent by the client together with the texts.
    input_type: str
# Endpoint to encode a batch of texts
@app.post("/encode_batch")
async def encode_batch(input: TextBatchInput):
    """Encode a batch of texts with the module-level Sentence Transformer.

    The ``input_type`` field of the request body is still accepted so that
    existing clients keep working, but it is not forwarded to the model.

    Returns ``{"embeddings": ...}`` where the value is the encoder output
    converted to nested lists. Any failure while encoding is reported as an
    HTTP 500 whose detail is the error text.
    """
    try:
        # SentenceTransformer.encode() has no `input_type` parameter; passing
        # it made the call fail and every request come back as a 500, so only
        # the texts are handed to the model.
        embeddings = model.encode(input.texts)
        return {"embeddings": embeddings.tolist()}
    except Exception as e:
        # Request boundary: surface the failure to the client as a 500 and
        # keep the original exception chained for server-side tracebacks.
        raise HTTPException(status_code=500, detail=str(e)) from e
# Root endpoint: GET / returns a fixed message, which makes it a quick way to
# verify the API is running.
@app.get("/")
async def root():
    status = {"message": "Sentence Transformer API is running."}
    return status
# Script entry point: serve the app with uvicorn when this file is run directly.
if __name__ == "__main__":
    # Imported here so uvicorn is only loaded when running as a script.
    import uvicorn

    # Listen on every network interface, port 7860.
    _server_options = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run(app, **_server_options)