Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions REFAC_NOTES.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# 임베딩 리팩토링 설명 및 향후 작업 제안

## 변경 의도와 배경
- **전역 예외 처리 흐름을 고려한 단일 변환기 도입**: 사용자/유기동물 임베딩 JSON 직렬화·역직렬화가 여러 서비스에 중복되어 있었고, 변환 실패 시 `Optional.empty()` 반환 이후 상위 계층에서 공통적으로 대체 로직을 타도록 설계되어 있어, 전역 예외 처리 정책(checked 예외 대신 도메인 로그 남기고 graceful degradation)을 준수하도록 변환 로직을 한 곳(EmbeddingJsonConverter)으로 모았습니다. 이는 변환 예외가 전역 예외 핸들러까지 전파되어 과도한 에러 응답을 만드는 문제를 방지하려는 목적입니다.
- **ObjectMapper 직접 사용 제거 이유**: 개별 서비스마다 `ObjectMapper`를 직접 주입해 `readValue/writeValueAsString`을 호출하면서 try/catch, 로그 메시지, 컨텍스트 문자열 구성이 반복되고 있었습니다. EmbeddingJsonConverter는 같은 역할을 공통화해 중복을 제거하고, 동일한 로깅 포맷/컨텍스트를 유지하도록 합니다. 또한 향후 임베딩 포맷 변경 시 단일 지점만 수정하면 됩니다.
- **log.warn 사용 최소화 이유**: 기존에는 서비스마다 변환 실패 시 `log.warn`을 남기거나 예외를 던졌는데, 추천이나 사용자 임베딩 계산 흐름에서는 일부 벡터가 손상되어도 나머지 데이터로 결과를 만들 수 있도록 설계되었습니다. 따라서 변환 실패를 "예측 가능한 품질 저하"로 간주하고, 변환기 내부에서 컨텍스트를 포함해 한 번만 기록하며, 상위 계층에서는 불필요한 중복 경고 로그를 남기지 않도록 조정했습니다.

## 변경 요약
- **EmbeddingJsonConverter 추가**: 임베딩 JSON 직렬화/역직렬화를 담당하며, 입력 공백/파싱 실패 시 컨텍스트 기반 경고를 남기고 `Optional.empty()`를 반환하도록 구현했습니다.
- **RecommendationService 개선**: 사용자·동물 임베딩을 모두 변환기를 통해 읽어와 벡터 길이 검증 후 코사인 유사도를 계산하도록 단순화했습니다.
- **UserInterestService 개선**: 관심/좋아요 가중치 맵을 만든 뒤 동물 임베딩을 변환기를 거쳐 누적·정규화하고, 직렬화 역시 변환기를 사용해 중복 코드를 제거했습니다.

## 향후 리팩토링 제안
1. **추천 계산의 스트리밍/배치 최적화**: 대규모 동물 임베딩을 한 번에 메모리에 로드하지 않도록 paging 또는 벡터 DB 연동 검토.
2. **임베딩 차원/스케일 검증 유틸 추가**: 벡터 길이 불일치나 NaN 검출을 공통 유틸로 추출해, 저장 전/후 자동 검증하도록 개선.
3. **비동기 임베딩 생성 파이프라인 모니터링**: 큐/스케줄러 실패 시 알림 및 재시도 정책을 설정하고, 상태 테이블/헬스 체크 엔드포인트 추가.
4. **추천 결과 캐싱**: 사용자 임베딩 변경 시 무효화되는 캐시 레이어(e.g., Redis)를 두어 반복 호출 비용을 절감.
5. **통합 로깅 포맷 정비**: 변환기 로그와 서비스 로그를 동일한 traceId/컨텍스트로 묶을 수 있도록 MDC 활용 규칙 수립.

## 테스트 실패 원인
- `./gradlew test` 실행 시 Maven Central 의존성 다운로드 단계에서 HTTP 403 오류가 발생하여 테스트가 수행되지 않았습니다. 외부 네트워크 제약으로 인한 환경적 실패로, 소스 변경과 직접적인 연관은 없습니다.
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package com.ganzi.backend.global.embedding;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.util.Optional;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;

@Slf4j
@Component
@RequiredArgsConstructor
public class EmbeddingJsonConverter {

private final ObjectMapper objectMapper;

public Optional<float[]> toVector(String embeddingJson, String context) {
if (embeddingJson == null || embeddingJson.isBlank()) {
log.warn("임베딩 JSON이 비어 있습니다. context={}", context);
return Optional.empty();
}

try {
return Optional.of(objectMapper.readValue(embeddingJson, new TypeReference<float[]>() {}));
} catch (JsonProcessingException e) {
log.warn("임베딩 역직렬화 실패 context={}", context, e);
return Optional.empty();
}
}

public Optional<String> toJson(float[] vector, String context) {
try {
return Optional.of(objectMapper.writeValueAsString(vector));
} catch (JsonProcessingException e) {
log.error("임베딩 직렬화 실패 context={}", context, e);
return Optional.empty();
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package com.ganzi.backend.global.embedding;

import java.util.Optional;
import lombok.extern.slf4j.Slf4j;

@Slf4j
public class WeightedEmbeddingAggregator {

private final EmbeddingJsonConverter embeddingJsonConverter;
private float[] sum;
private double totalWeight;

public WeightedEmbeddingAggregator(EmbeddingJsonConverter embeddingJsonConverter) {
this.embeddingJsonConverter = embeddingJsonConverter;
}

public void add(String embeddingJson, double weight, String context) {
embeddingJsonConverter.toVector(embeddingJson, context)
.ifPresent(vector -> accumulate(vector, weight, context));
}

public Optional<float[]> normalizedAverage() {
if (sum == null || totalWeight == 0.0) {
return Optional.empty();
}

float[] average = new float[sum.length];
for (int i = 0; i < sum.length; i++) {
average[i] = sum[i] / (float) totalWeight;
}

double norm = 0.0;
for (float v : average) {
norm += v * v;
}

norm = Math.sqrt(norm);
if (norm > 0) {
for (int i = 0; i < average.length; i++) {
average[i] /= (float) norm;
}
}

return Optional.of(average);
}

private void accumulate(float[] vector, double weight, String context) {
if (sum == null) {
sum = new float[vector.length];
} else if (sum.length != vector.length) {
log.warn("임베딩 차원이 일치하지 않아 스킵합니다. context={}", context);
return;
}

for (int i = 0; i < vector.length; i++) {
sum[i] += vector[i] * (float) weight;
}
totalWeight += weight;
}
}
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
package com.ganzi.backend.recommendation.application;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.ganzi.backend.animal.domain.AnimalEmbedding;
import com.ganzi.backend.animal.domain.repository.AnimalEmbeddingRepository;
import com.ganzi.backend.user.domain.UserEmbedding;
import com.ganzi.backend.user.domain.repository.UserEmbeddingRepository;
import com.ganzi.backend.global.embedding.EmbeddingJsonConverter;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
Expand All @@ -15,7 +12,6 @@
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Optional;

@Slf4j
@Service
Expand All @@ -25,19 +21,13 @@ public class RecommendationService {

private final UserEmbeddingRepository userEmbeddingRepository;
private final AnimalEmbeddingRepository animalEmbeddingRepository;
private final ObjectMapper objectMapper;
private final EmbeddingJsonConverter embeddingJsonConverter;

public List<String> recommend(Long userId, int top) {
Optional<UserEmbedding> optUserEmbedding = userEmbeddingRepository.findByUserId(userId);
float[] userVector = null;
if (optUserEmbedding.isPresent()) {
try {
userVector = objectMapper.readValue(
optUserEmbedding.get().getEmbeddingJson(), new TypeReference<float[]>() {});
} catch (JsonProcessingException e) {
log.warn("user {} 임베딩 오류 : 역직렬화 실패", userId, e);
}
}
float[] userVector = userEmbeddingRepository.findByUserId(userId)
.flatMap(embedding -> embeddingJsonConverter.toVector(
embedding.getEmbeddingJson(), "userId=" + userId))
.orElse(null);

List<AnimalEmbedding> embeddings = animalEmbeddingRepository.findAll();

Expand All @@ -61,17 +51,14 @@ private List<RecommendationScore> calculateScores(float[] userVector, List<Anima
List<RecommendationScore> scores = new ArrayList<>();

for (AnimalEmbedding emb : embeddings) {
try {
float[] animalVector = objectMapper.readValue(
emb.getEmbeddingJson(), new TypeReference<float[]>() {});
if (animalVector.length != userVector.length) {
continue;
}
double score = cosineSimilarity(userVector, animalVector);
scores.add(new RecommendationScore(emb.getDesertionNo(), score));
} catch (JsonProcessingException e) {
log.warn("animal {} 임베딩 오류 : Score 계산 실패", emb.getDesertionNo(), e);
float[] animalVector = embeddingJsonConverter.toVector(
emb.getEmbeddingJson(), "animal desertionNo=" + emb.getDesertionNo())
.orElse(null);
if (animalVector == null || animalVector.length != userVector.length) {
continue;
}
double score = cosineSimilarity(userVector, animalVector);
scores.add(new RecommendationScore(emb.getDesertionNo(), score));
}

return scores;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
package com.ganzi.backend.user.application;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.ganzi.backend.animal.domain.Animal;
import com.ganzi.backend.animal.domain.AnimalEmbedding;
import com.ganzi.backend.animal.domain.repository.AnimalEmbeddingRepository;
import com.ganzi.backend.animal.domain.repository.AnimalRepository;
import com.ganzi.backend.global.code.status.ErrorStatus;
import com.ganzi.backend.global.exception.GeneralException;
import com.ganzi.backend.global.embedding.EmbeddingJsonConverter;
import com.ganzi.backend.global.embedding.WeightedEmbeddingAggregator;
import com.ganzi.backend.user.domain.User;
import com.ganzi.backend.user.domain.UserEmbedding;
import com.ganzi.backend.user.domain.UserInterest;
Expand Down Expand Up @@ -40,7 +39,7 @@ public class UserInterestService {
private final AnimalRepository animalRepository;
private final AnimalEmbeddingRepository animalEmbeddingRepository;
private final UserLikeRepository userLikeRepository;
private final ObjectMapper objectMapper;
private final EmbeddingJsonConverter embeddingJsonConverter;


@Transactional
Expand Down Expand Up @@ -70,69 +69,38 @@ public void computeUserEmbedding(Long userId) {
.orElseThrow(() -> new GeneralException(ErrorStatus.USER_NOT_FOUND));

Map<String, Double> weightMap = buildWeightMap(user);
if(weightMap.isEmpty()) {
if (weightMap.isEmpty()) {
return;
}

float[] sum = null;
double totalWeight = 0.0;
for (Map.Entry<String, Double> entry : weightMap.entrySet()) {
String deserNo = entry.getKey();
double weight = entry.getValue();
Optional<AnimalEmbedding> optEmbedding = animalEmbeddingRepository.findById(deserNo);
if (optEmbedding.isEmpty()) {
continue;
}
float[] vector;
try {
vector = objectMapper.readValue(
optEmbedding.get().getEmbeddingJson(),
new TypeReference<float[]>() {}
);
} catch (JsonProcessingException e) {
log.warn("Animal Embedding 역직렬화 실패 desertionNo={}", deserNo, e);
continue;
}
if (sum == null) {
sum = new float[vector.length];
}
for (int i = 0; i < vector.length; i++) {
sum[i] += (float) (vector[i] * weight);
}
totalWeight += weight;
}

if (sum == null || totalWeight == 0.0) {
Optional<float[]> normalizedAverage = aggregateUserEmbedding(weightMap);
if (normalizedAverage.isEmpty()) {
return;
}

for (int i = 0; i < sum.length; i++) {
sum[i] /= (float) totalWeight;
}

// L2 정규화
double norm = 0.0;
for (float v : sum) {
norm += v * v;
}
norm = Math.sqrt(norm);
if (norm > 0) {
for (int i = 0; i < sum.length; i++) {
sum[i] /= (float) norm;
}
}
float[] userEmbeddingVector = normalizedAverage.get();

UserEmbedding userEmbedding = userEmbeddingRepository.findByUserId(user.getId())
.orElseGet(() -> UserEmbedding.builder().user(user).build());

try {
String json = objectMapper.writeValueAsString(sum);
userEmbedding.updateUserEmbedding(json, sum.length);
userEmbeddingRepository.save(userEmbedding);
} catch (JsonProcessingException e) {
log.error("User Embedding 직렬화 실패 userId={}", user.getId(), e);
throw new GeneralException(ErrorStatus.DATABASE_ERROR);
String embeddingJson = embeddingJsonConverter.toJson(userEmbeddingVector, "userId=" + user.getId())
.orElseThrow(() -> new GeneralException(ErrorStatus.DATABASE_ERROR));

userEmbedding.updateUserEmbedding(embeddingJson, userEmbeddingVector.length);
userEmbeddingRepository.save(userEmbedding);
}

private Optional<float[]> aggregateUserEmbedding(Map<String, Double> weightMap) {
WeightedEmbeddingAggregator aggregator = new WeightedEmbeddingAggregator(embeddingJsonConverter);

for (Map.Entry<String, Double> entry : weightMap.entrySet()) {
String desertionNo = entry.getKey();
double weight = entry.getValue();
animalEmbeddingRepository.findById(desertionNo)
.map(AnimalEmbedding::getEmbeddingJson)
.ifPresent(json -> aggregator.add(json, weight, "animal desertionNo=" + desertionNo));
}

return aggregator.normalizedAverage();
}


Expand Down