agencies

c 소스코드 추상화 해보기 (변수, 파라미터) 본문

Ⅰ. 프로그래밍

c 소스코드 추상화 해보기 (변수, 파라미터)

agencies 2024. 11. 28. 13:29

변수 : var

파라미터 : param 으로 추상화 진행

 

소스코드

from pycparser import c_parser, c_ast
import re

def remove_preprocessor_directives(code):
    """
    C 코드에서 주석 및 전처리 지시문 제거
    """
    # 주석 제거
    code = re.sub(r'//.*?\n|/\*.*?\*/', '', code, flags=re.S)
    # 전처리 지시문 제거
    code = "\n".join([line for line in code.splitlines() if not line.strip().startswith("#")])
    return code

class Abstractor(c_ast.NodeVisitor):
    def __init__(self):
        self.param_map = {}
        self.var_map = {}
        self.func_names = set()  # 함수 이름 저장

    def visit_FuncDef(self, node):
        # 함수 이름 저장
        self.func_names.add(node.decl.name)
        
        # 함수 매개변수 추상화
        if node.decl.type.args:
            for param in node.decl.type.args.params:
                self.param_map[param.name] = "param"

        self.generic_visit(node)

    def visit_Decl(self, node):
        # 내부 변수 이름 추상화
        if node.name not in self.param_map and node.name not in self.func_names:
            if node.name not in self.var_map:  # 이미 치환된 이름이 아니면
                self.var_map[node.name] = "var"

        self.generic_visit(node)

    def abstract_code(self, code):
        # 매개변수 및 내부 변수 이름을 치환
        abstracted_code = code
        # 매개변수 이름 치환
        for original, abstracted in self.param_map.items():
            abstracted_code = self._replace_safe(abstracted_code, original, abstracted)
        # 내부 변수 이름 치환
        for original, abstracted in self.var_map.items():
            abstracted_code = self._replace_safe(abstracted_code, original, abstracted)
        return abstracted_code

    def _replace_safe(self, code, old, new):
        """
        안전하게 변수명을 교체하기 위한 함수.
        함수 이름과 충돌하지 않도록 전체 단어만 교체.
        """
        pattern = r'\b' + re.escape(old) + r'\b'
        return re.sub(pattern, new, code)

# 원본 코드
with open("test.c", "r", encoding="utf-8") as code:
    c_code = code.read()

print("원본 코드:")
print(c_code)

# 전처리 지시문 제거
processed_code = remove_preprocessor_directives(c_code)

# 파싱 및 추상화
parser = c_parser.CParser()
ast = parser.parse(processed_code)

abstractor = Abstractor()
abstractor.visit(ast)

# 추상화된 코드 출력
abstracted_code = abstractor.abstract_code(processed_code)
print("추상화된 코드:")
print(abstracted_code)

 

 

추상화 전

 

추상화된 코드

 


고도화! 진행

공백 제거 + printf 내용 "내용" -> text로 추상화

해시처리 및 csv 저장 

import hashlib
import csv
from pycparser import c_parser, c_ast
import re

def remove_preprocessor_directives(code):
    """
    C 코드에서 주석 및 전처리 지시문 제거
    """
    # 주석 제거
    code = re.sub(r'//.*?\n|/\*.*?\*/', '', code, flags=re.S)
    # 전처리 지시문 제거
    code = "\n".join([line for line in code.splitlines() if not line.strip().startswith("#")])
    return code

def split_statements(code):
    """
    ;로 끝나는 구문을 기준으로 한 줄씩 분리
    """
    lines = code.splitlines()
    split_lines = []
    for line in lines:
        statements = line.split(";")
        for i, statement in enumerate(statements):
            statement = statement.strip()
            if statement:  # 비어 있지 않은 경우만 추가
                if i < len(statements) - 1:  # 마지막 문장이 아니라면 ; 추가
                    split_lines.append(statement + ";")
                else:
                    split_lines.append(statement)
    return split_lines

def transform_printf(line):
    """
    printf 구문을 분석하여 문자열 리터럴을 'text'로 변환
    """
    printf_pattern = r'printf\s*\((.*?)\);'
    match = re.match(printf_pattern, line)
    if match:
        content = match.group(1)
        # 문자열 리터럴("...")을 text로 변환
        transformed = re.sub(r'"[^"]*"', 'text', content)
        return f'printf({transformed});'
    return line

class Abstractor(c_ast.NodeVisitor):
    def __init__(self):
        self.param_map = {}
        self.var_map = {}
        self.func_names = set()  # 함수 이름 저장

    def visit_FuncDef(self, node):
        self.func_names.add(node.decl.name)
        if node.decl.type.args:
            for param in node.decl.type.args.params:
                self.param_map[param.name] = "param"

        self.generic_visit(node)

    def visit_Decl(self, node):
        if node.name not in self.param_map and node.name not in self.func_names:
            if node.name not in self.var_map:
                self.var_map[node.name] = "var"

        self.generic_visit(node)

    def abstract_code(self, code):
        abstracted_code = code
        for original, abstracted in self.param_map.items():
            abstracted_code = self._replace_safe(abstracted_code, original, abstracted)
        for original, abstracted in self.var_map.items():
            abstracted_code = self._replace_safe(abstracted_code, original, abstracted)
        return abstracted_code

    def _replace_safe(self, code, old, new):
        """
        안전하게 변수명을 교체하기 위한 함수.
        """
        pattern = r'\b' + re.escape(old) + r'\b'
        return re.sub(pattern, new, code)

def hash_line(line):
    """
    스페이스바 제거 후 주어진 한 줄의 코드를 해싱합니다.
    """
    normalized_line = line.replace(" ", "").replace("\t", "")  # 공백 및 탭 제거
    return hashlib.sha256(normalized_line.encode('utf-8')).hexdigest()

# 원본 코드
with open("test.c", "r", encoding="utf-8") as code:
    c_code = code.read()

print("원본 코드:")
print(c_code)

# 전처리 지시문 제거
processed_code = remove_preprocessor_directives(c_code)

# ;로 구분하여 각 문장을 한 줄씩 분리
split_code = "\n".join(split_statements(processed_code))

# printf 구문 변환
transformed_lines = [transform_printf(line) for line in split_code.splitlines()]
transformed_code = "\n".join(transformed_lines)

# 파싱 및 추상화
parser = c_parser.CParser()
ast = parser.parse(transformed_code)

abstractor = Abstractor()
abstractor.visit(ast)

# 추상화된 코드 출력
abstracted_code = abstractor.abstract_code(transformed_code)

# 각 줄의 코드와 해시값 추출
rows = []
for line in abstracted_code.splitlines():
    line = line.strip()  # 각 줄의 앞뒤 공백 제거
    if line:  # 빈 줄은 무시
        hashed_line = hash_line(line)
        rows.append({"abs_code": line.replace(" ", "").replace("\t", ""), "hash": hashed_line})

# 결과를 CSV로 저장
output_file = "abstracted_code.csv"
with open(output_file, "w", newline="", encoding="utf-8") as csvfile:
    fieldnames = ["abs_code", "hash"]
    writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
    writer.writeheader()
    writer.writerows(rows)

print(f"추상화된 코드와 해시값이 {output_file}에 저장되었습니다.")

 

 


 

(업데이트)

backward slicing + 추상화

 

import subprocess
import json
import hashlib
import csv
from pycparser import c_parser, c_ast
import re
import os

# Joern 파일 경로 및 함수 이름
file_name = "/content/test.c"
func_name = "memory_leak"
csv_file_name = "/content/backward-db.csv"

class Abstractor(c_ast.NodeVisitor):
    def __init__(self):
        self.param_map = {}
        self.var_map = {}
        self.func_names = set()

    def visit_FuncDef(self, node):
        self.func_names.add(node.decl.name)
        if node.decl.type.args:
            for param in node.decl.type.args.params:
                self.param_map[param.name] = "param"
        self.generic_visit(node)

    def visit_Decl(self, node):
        if node.name not in self.param_map and node.name not in self.func_names:
            if node.name not in self.var_map:
                self.var_map[node.name] = "var"
        self.generic_visit(node)

    def abstract_code(self, code):
        abstracted_code = code
        for original, abstracted in self.param_map.items():
            abstracted_code = self._replace_safe(abstracted_code, original, abstracted)
        for original, abstracted in self.var_map.items():
            abstracted_code = self._replace_safe(abstracted_code, original, abstracted)
        return abstracted_code

    def _replace_safe(self, code, old, new):
        pattern = r'\b' + re.escape(old) + r'\b'
        return re.sub(pattern, new, code)

def remove_comments_and_preprocessor(code):
    """
    주석 및 전처리 지시문 제거 후 개행을 유지합니다.
    """
    # 라인 주석 제거
    code = re.sub(r'//.*', '', code)
    # 블록 주석 제거
    code = re.sub(r'/\*.*?\*/', lambda match: '\n' * match.group(0).count('\n'), code, flags=re.S)
    # 전처리 지시문(#include 등) 제거
    code = re.sub(r'#.*', '', code)
    return code

def process_code(input_code):
    """
    주어진 C 코드를 추상화하고 결과를 반환합니다.
    """
    # 주석 및 전처리 지시문 제거 후 개행 유지
    code_without_comments = remove_comments_and_preprocessor(input_code)

    # 파싱 및 추상화
    parser = c_parser.CParser()
    abstractor = Abstractor()

    try:
        ast = parser.parse(code_without_comments)
        abstractor.visit(ast)
        abstracted_code = abstractor.abstract_code(code_without_comments)
        return abstracted_code
    except Exception as e:
        print(f"코드 파싱 오류: {e}")
        return code_without_comments  # 오류 발생 시 원본 반환

# 사용자 코드 읽기
with open(file_name, "r", encoding='utf-8') as user_code_file:
    user_code = user_code_file.read()

# 코드 추상화
abstracted_code = process_code(user_code)

# 추상화된 코드 저장
output_file = file_name.split('.c')[0] + '_tmp.c'
with open(output_file, "w", encoding="utf-8") as file:
    file.write(abstracted_code)

print(f"추상화된 코드가 {output_file}에 저장되었습니다.")

# Joern 스크립트 생성
joern_script_content = f"""
importCode("{file_name}")

// 이미 추적한 줄 번호를 저장하여 중복 방지
val visitedLines = scala.collection.mutable.Set[Int]()
val result = scala.collection.mutable.ListBuffer[Int]()

// 특정 변수의 backward slicing 수행
def traceBackwardSlicing(variableName: String, methodName: String): Unit = {{
  println("=== 함수 '" + methodName + "'의 변수 '" + variableName + "' backward slicing ===")

  // 변수 사용 위치 추적 (methodName 내부로 제한)
  val variableUses = cpg.identifier.nameExact(variableName).where(_.method.nameExact(methodName))
  variableUses.foreach {{ useNode =>
    // 데이터 의존성 추적 (ddgIn)
    val dataDependencies = useNode.ddgIn
    dataDependencies.foreach {{ node =>
      val lineNumber = node.lineNumber.getOrElse(-1)
      val nodeCode = node.code
      if (!visitedLines.contains(lineNumber) && lineNumber != -1 &&
          nodeCode != null) {{
        visitedLines.add(lineNumber)
        result += lineNumber
        println(s"데이터 의존성 있는 코드: ${{node.code}} (줄 번호: $lineNumber)")
      }}
    }}

    // 변수 사용이 포함된 모든 호출을 추적
    val callsUsingVariable = useNode.inCall
    callsUsingVariable.foreach {{ callNode =>
      val lineNumber = callNode.lineNumber.getOrElse(-1)
      val nodeCode = callNode.code
      if (!visitedLines.contains(lineNumber) && lineNumber != -1 &&
          nodeCode != null && nodeCode.contains(variableName)) {{
        visitedLines.add(lineNumber)
        result += lineNumber
        println(s"변수를 사용하는 호출: ${{callNode.code}} (줄 번호: $lineNumber)")
      }}
    }}

    // 제어 흐름 상의 관련성 추적 (cfgIn)
    val controlDependencies = useNode.cfgIn
    controlDependencies.foreach {{ node =>
      val lineNumber = node.lineNumber.getOrElse(-1)
      val nodeCode = node.code
      if (!visitedLines.contains(lineNumber) && lineNumber != -1 &&
          nodeCode != null && nodeCode.contains(variableName)) {{
        visitedLines.add(lineNumber)
        result += lineNumber
        println(s"제어 흐름 관련 코드: ${{node.code}} (줄 번호: $lineNumber)")
      }}
    }}
  }}
}}

// Main 함수 작업
val callsInMain = cpg.method.name("main").call
val funcCalls = callsInMain.name("{func_name}")

if (!funcCalls.isEmpty) {{
  println("Main 함수에서 취약한 함수 발견!")
  funcCalls.argument.foreach {{ arg =>
    val variableName = arg.code
    println(s"파라미터: ${{variableName}}")
    traceBackwardSlicing(variableName, "main")
  }}

  // 함수 호출 추적
  funcCalls.foreach {{ callNode =>
    callNode.inAssignment.foreach {{ assignNode =>
      val assignedVariable = assignNode.target.code
      println(s"함수 결과가 변수 '${{assignedVariable}}'에 할당됨")
      traceBackwardSlicing(assignedVariable, "main")
    }}
  }}
}}

// 지역 함수 작업 (vulnerable_function)
val methods = cpg.method.name("{func_name}")
if (methods.size > 0) {{
  methods.foreach {{ method =>
    println(s"=== 함수 '${{method.name}}' 작업 ===")
    method.parameter.foreach {{ param =>
      val paramName = param.name
      println(s"파라미터: ${{paramName}}")
      traceBackwardSlicing(paramName, method.name)
    }}

    // 함수 내 데이터 의존성 추적 (strcpy와 관련된 변수만 포함)
    method.call.name("strcpy").argument.foreach {{ arg =>
      val variableName = arg.code
      println(s"관련 변수: ${{variableName}}")
      traceBackwardSlicing(variableName, method.name)
    }}
  }}
}}

// JSON 형식으로 결과 출력
println(ujson.write(result.toList))
"""


# Joern 스크립트 저장
script_path = "/content/trace.sc"
with open(script_path, "w") as script_file:
    script_file.write(joern_script_content)


# Joern 실행
result = subprocess.run(
    ["/content/joern-cli/joern", "--script", script_path],
    stdout=subprocess.PIPE,
    stderr=subprocess.PIPE,
    text=True
)



# Joern 실행 결과 출력 및 CSV 생성
if result.returncode == 0:
    print("Joern 실행 성공!")
    output = result.stdout.strip()
    print("=== 분석 결과 ===")
    print(output)
    


    try:
        # JSON 결과 추출 및 Python 변수로 저장
        lines = json.loads(output.splitlines()[-1])  # 마지막 JSON 출력 파싱
        lines.sort()  # 리스트를 오름차순으로 정렬
        print(f"Python 변수로 저장된 라인 번호 (오름차순): {lines}")
        
        # 파일 읽어서 특정 라인 번호의 소스 코드 추출
        with open(file_name, "r") as file:
            file_content = file.readlines()  # 원본 파일 내용
        with open(output_file, "r") as tmp_file:
            tmp_file_content = tmp_file.readlines()  # 추상화된 파일 내용

        print("=== 특정 라인의 소스 코드 ===")
        rows = []  # CSV에 저장할 데이터를 담을 리스트
        for line_number in lines:
            if 0 < line_number <= len(file_content):  # 유효한 라인 번호인지 확인
                origin_code = file_content[line_number - 1].strip()  # 원본 코드
                abs_code = tmp_file_content[line_number - 1].strip() if line_number <= len(tmp_file_content) else ""  # 추상화된 코드
                print(f"{line_number}: origin: {origin_code}, abs_code: {abs_code}")
                rows.append({"origin": origin_code, "abs_code": abs_code})
        
        # 기존 CSV 읽어서 중복 확인
        existing_hashes = set()
        if os.path.exists(csv_file_name):
            with open(csv_file_name, "r", encoding="utf-8") as csv_file:
                csv_reader = csv.DictReader(csv_file)
                for row in csv_reader:
                    if "hash" in row:  # 해시 값이 있는지 확인
                        existing_hashes.add(row["hash"])

        # CSV 파일 업데이트
        with open(csv_file_name, "a", newline="", encoding="utf-8") as csv_file:  # 'a' 모드로 파일 열기
            csv_writer = csv.DictWriter(csv_file, fieldnames=["file_name", "func_name", "origin", "abs_code", "hash"])
            
            # 헤더 작성 (처음 생성될 때만 추가)
            if os.stat(csv_file_name).st_size == 0:
                csv_writer.writeheader()
            
            for row in rows:
                # 추상화된 코드(abs_code)를 정규화한 후 해시 생성
                normalized_abs_code = re.sub(r'\s+', '', row["abs_code"])
                hash_value = hashlib.sha256(normalized_abs_code.encode("utf-8")).hexdigest()
                if hash_value not in existing_hashes:  # 중복 확인
                    csv_writer.writerow({
                        "file_name": file_name,
                        "func_name": func_name,
                        "origin": row["origin"],
                        "abs_code": normalized_abs_code,
                        "hash": hash_value
                    })
                    existing_hashes.add(hash_value)  # 새로운 해시 추가
        
        print(f"CSV 파일 업데이트 완료: {csv_file_name}")

    except json.JSONDecodeError as e:
        print("JSON 디코딩 오류:", e)
else:
    print("Joern 실행 실패!")
    print("=== 오류 메시지 ===")
    print(result.stderr.strip())

 

이제 검증만 더 고도화 하면 될 것 같다. -> 이미지와 다르게 공백도 처리하게 됐다