(* Theory Simple_Memory *)

section ‹Simple Memory Model›
theory Simple_Memory
imports "../../lib/LLVM_Integer" "../../lib/LLVM_Double"  "../../lib/MM/MMonad" 
begin

  text ‹Here, we combine a model of LLVM values with our generic block-based memory model.›

  (* LLVM pointers: either the null pointer, or a pointer to an address of the memory model. *)
  datatype llvm_ptr 
    = is_null: PTR_NULL 
    | is_addr: PTR_ADDR (the_addr: addr)
  hide_const (open) 
    llvm_ptr.is_null 
    llvm_ptr.is_addr 
    llvm_ptr.the_addr

    
  (* Save the current lifting setup into the bundle memory.lifting, then remove that
     setup from this theory's context (presumably the lifting setup of the memory type
     from MMonad -- confirm). The order of these two commands matters. *)
  lifting_update memory.lifting
  lifting_forget memory.lifting
        
  subsection ‹LLVM Values›

  (* LLVM values: structures (a list of field values), integers, doubles, and pointers.
     TODO (LL_DOUBLE): Similar to lint, we could encode different floating-point layouts here, 
       and restrict the code-generator to only accept the ones supported by LLVM. *)
  datatype llvm_val 
    = is_struct: LL_STRUCT (the_fields: "llvm_val list") 
    | is_int: LL_INT (the_int: lint) 
    | is_double: LL_DOUBLE (the_double: double)
    | is_ptr: LL_PTR (the_ptr: llvm_ptr)
  hide_const (open) 
    llvm_val.is_struct llvm_val.is_int llvm_val.is_double llvm_val.is_ptr
    llvm_val.the_fields llvm_val.the_int llvm_val.the_double llvm_val.the_ptr
  
  
  (* Shapes of LLVM values: a structure of field shapes, an integer with its width,
     a double, or a pointer. *)
  datatype llvm_struct 
    = is_struct: VS_STRUCT (the_fields: "llvm_struct list") 
    | is_int: VS_INT (the_width: nat)
    | is_double: VS_DOUBLE
    | is_ptr: VS_PTR 
  hide_const (open) 
    llvm_struct.is_struct llvm_struct.is_int llvm_struct.is_double llvm_struct.is_ptr
    llvm_struct.the_fields llvm_struct.the_width
  
  
  
  (* Shape of a value: structures are mapped field-wise, integers keep their width,
     doubles and pointers map to their (payload-free) shape constructors. *)
  fun llvm_struct_of_val :: "llvm_val ⇒ llvm_struct" where
    "llvm_struct_of_val (LL_STRUCT fields) = VS_STRUCT (map llvm_struct_of_val fields)"
  | "llvm_struct_of_val (LL_INT n) = VS_INT (width n)"
  | "llvm_struct_of_val (LL_DOUBLE _) = VS_DOUBLE"
  | "llvm_struct_of_val (LL_PTR _) = VS_PTR"

  (* Default value of a shape: structures are built field-wise, integers are lconst w 0,
     doubles are double_of_word 0, and pointers are PTR_NULL. *)
  fun llvm_zero_initializer :: "llvm_struct ⇒ llvm_val" where
    "llvm_zero_initializer (VS_STRUCT shapes) = LL_STRUCT (map llvm_zero_initializer shapes)"
  | "llvm_zero_initializer (VS_INT w) = LL_INT (lconst w 0)"
  | "llvm_zero_initializer VS_DOUBLE = LL_DOUBLE (double_of_word 0)"
  | "llvm_zero_initializer VS_PTR = LL_PTR PTR_NULL"
  
  (* The zero initializer of a shape has exactly that shape (structural induction on the shape). *)
  lemma struct_of_llvm_zero_initializer[simp]: "llvm_struct_of_val (llvm_zero_initializer s) = s"
    by (induction s) (simp_all add: map_idI)

  (*type_synonym llvm_memory = "llvm_val memory"
  translations (type) "llvm_memory" ↽ (type) "llvm_val memory"
  *)

  (* Monad of computations on the memory model, instantiated to llvm_val
     (presumably the type of stored values -- see Mmalloc usage below).
     The print-only translation (↽) makes ('a, llvm_val) M display as 'a llM. *)
  type_synonym 'a llM = "('a,llvm_val) M"
  translations
    (type) "'a llM" ↽ (type) "('a, llvm_val) M"

    
  subsection ‹Raw operations on values›
  context
    includes monad_syntax_M
  begin
    
  (* Extract the address from a non-null pointer value; fails on null and on non-pointers. *)
  definition llvm_extract_addr :: "llvm_val ⇒ addr llM" where
    "llvm_extract_addr v ≡ case v of LL_PTR (PTR_ADDR a) ⇒ return a | _ ⇒ fail"

  (* Extract the pointer (possibly null) from a pointer value; fails on non-pointers. *)
  definition llvm_extract_ptr :: "llvm_val ⇒ llvm_ptr llM" where
    "llvm_extract_ptr v ≡ case v of LL_PTR p ⇒ return p | _ ⇒ fail"
    
  (* Read an integer value via lint_to_sint (signed interpretation); fails on non-integers. *)
  definition llvm_extract_sint :: "llvm_val ⇒ int llM" where
    "llvm_extract_sint v ≡ case v of LL_INT i ⇒ return (lint_to_sint i) | _ ⇒ fail" 
        
  (* Read an integer value as a natural number via nat (lint_to_uint i); fails on non-integers. *)
  definition llvm_extract_unat :: "llvm_val ⇒ nat llM" where
    "llvm_extract_unat v ≡ case v of LL_INT i ⇒ return (nat (lint_to_uint i)) | _ ⇒ fail" 

  (* Return field i of a structure value.
     Fails if v is not a structure; the assertion fails if i is out of range. *)
  definition llvm_extract_value :: "llvm_val ⇒ nat ⇒ llvm_val llM" where 
  "llvm_extract_value v i ≡ case v of 
    LL_STRUCT vs ⇒ doM {
      assert (i<length vs);
      return (vs!i)
    }
  | _ ⇒ fail"
      
  (* Replace field i of a structure value by x.
     Fails if v is not a structure; asserts that i is in range and that x has the same
     shape (llvm_struct_of_val) as the field it replaces. *)
  definition llvm_insert_value :: "llvm_val ⇒ llvm_val ⇒ nat ⇒ llvm_val llM" where 
  "llvm_insert_value v x i ≡ case v of 
    LL_STRUCT vs ⇒ doM {
      assert (i<length vs);
      assert (llvm_struct_of_val (vs!i) = llvm_struct_of_val x);
      return (LL_STRUCT (vs[i:=x]))
    }
  | _ ⇒ fail"

    
  subsection ‹Interface functions›
  
  subsubsection ‹Typed arguments›
    
  (* TODO: redundancy with is_valid_addr! *)
  (* Check an address with Mvalid_addr. *)
  definition llvmt_check_addr :: "addr ⇒ unit llM" where "llvmt_check_addr a ≡ doM { 
    Mvalid_addr a
  }"
    
  (* Load the value stored at address a (delegates to Mload). *)
  definition llvmt_load :: "addr ⇒ llvm_val llM" where "llvmt_load a ≡ doM { 
    Mload a
  }"
  
  (* Store x at address a. The current content is loaded first, so a must be loadable,
     and x must have the same shape as the value it overwrites. *)
  definition "llvmt_store x a ≡ doM { 
    xorig ← llvmt_load a; 
    assert (llvm_struct_of_val x = llvm_struct_of_val xorig);
    Mstore a x
  }"
  
  (* Allocate n cells, each initialized with the zero value of shape s (via Mmalloc).
     The result of Mmalloc is returned; llvmt_allocp uses it as a block index. *)
  definition "llvmt_alloc s n ≡ doM {
    Mmalloc (replicate n (llvm_zero_initializer s))
  }"
  
  (* Free block b (delegates to Mfree). *)
  definition llvmt_free :: "nat ⇒ unit llM" where "llvmt_free b ≡ doM {
    Mfree b
  }"  

  (* Free the block a pointer refers to. The pointer must be non-null and point to
     index 0 of its block (i.e. be a pointer as returned by llvmt_allocp);
     freeing PTR_NULL fails here. *)
  definition "llvmt_freep p ≡ doM {
    assert (llvm_ptr.is_addr p);
    let a = llvm_ptr.the_addr p;
  
    assert (addr.index a=0);
    llvmt_free (addr.block a);
    return ()
  }"  

  (* Allocate n zero-initialized cells of shape s and return a pointer to index 0
     of the new block. *)
  definition "llvmt_allocp s n ≡ doM {
    b ← llvmt_alloc s n;
    return (PTR_ADDR (ADDR b 0))
  }"
    
  
  (* Check a pointer: PTR_NULL is always accepted; a non-null pointer is checked
     with Mvalid_addr on its address. *)
  definition llvmt_check_ptr :: "llvm_ptr ⇒ unit llM" where "llvmt_check_ptr p ≡ 
    if llvm_ptr.is_null p then return ()
    else doM {
      let a = llvm_ptr.the_addr p;
      Mvalid_addr a ― ‹TODO: support 1-beyond-end pointers!›
    }"
      
  (* Pointer arithmetic: add ofs to the index of a non-null pointer, staying in the same block.
     Fails for null. The result is checked with llvmt_check_ptr; per the TODO there,
     one-beyond-end pointers are not supported yet. *)
  definition "llvmt_ofs_ptr p ofs ≡ doM {
    assert (llvm_ptr.is_addr p);
    let a = llvm_ptr.the_addr p;
    let b = addr.block a;
    let i = addr.index a;
    let i = i + ofs;
    let r = PTR_ADDR (ADDR b i);
    llvmt_check_ptr r;
    return r
  }"  
    
  (* Precondition of pointer comparison. If either pointer is null, the comparison is
     allowed without further checks; otherwise both pointers are checked with
     llvmt_check_ptr. The disjunction is essential: llvmt_check_ptr already accepts null,
     so the special case only matters when exactly one side is null. *)
  definition "llvmt_check_ptrcmp p1 p2 ≡ 
    if p1=PTR_NULL ∨ p2=PTR_NULL then 
      return () 
    else doM {
      llvmt_check_ptr p1;
      llvmt_check_ptr p2
    }"
  
  (* Pointer equality, guarded by llvmt_check_ptrcmp. *)
  definition "llvmt_ptr_eq p1 p2 ≡ doM {
    llvmt_check_ptrcmp p1 p2;
    return (p1 = p2)
  }"
  
  (* Pointer disequality, guarded by llvmt_check_ptrcmp. *)
  definition "llvmt_ptr_neq p1 p2 ≡ doM {
    llvmt_check_ptrcmp p1 p2;
    return (p1 ≠ p2)
  }"
  
  subsubsection ‹Embedded arguments›

  (* Load through a pointer value; fails unless a is a non-null pointer value. *)
  definition "llvm_load a ≡ doM {
    a ← llvm_extract_addr a;
    llvmt_load a
  }"
  
  (* Store x through a pointer value; fails unless a is a non-null pointer value. *)
  definition "llvm_store x a ≡ doM {
    a ← llvm_extract_addr a;
    llvmt_store x a
  }"
  
  (* Allocate n cells of shape s, where n is an integer value read as unsigned;
     returns the new pointer wrapped as a value. *)
  definition "llvm_alloc s n ≡ doM {
    n ← llvm_extract_unat n;
    p ← llvmt_allocp s n;
    return (LL_PTR p)
  }"

  (* Return the block of an address at index 0; fails for any other index. *)
  definition "llvm_extract_base_block a ≡ case a of ADDR b i ⇒ if i=0 then return b else fail"
  
  (* Free through a pointer value (see llvmt_freep for the preconditions). *)
  definition "llvm_free p ≡ doM {
    p ← llvm_extract_ptr p;
    llvmt_freep p
  }"
  
  (* Pointer arithmetic on values: the offset is an integer value read as signed. *)
  definition "llvm_ofs_ptr p ofs ≡ doM {
    p ← llvm_extract_ptr p;
    ofs ← llvm_extract_sint ofs;
    r ← llvmt_ofs_ptr p ofs;
    return (LL_PTR r)
  }"  
  
  
  (* Pointer equality on values; the boolean result is encoded with bool_to_lint. *)
  definition "llvm_ptr_eq p1 p2 ≡ doM {
    p1 ← llvm_extract_ptr p1;
    p2 ← llvm_extract_ptr p2;
    r ← llvmt_ptr_eq p1 p2;
    return (LL_INT (bool_to_lint r))
  }"  
  
  (* Pointer disequality on values; the boolean result is encoded with bool_to_lint. *)
  definition "llvm_ptr_neq p1 p2 ≡ doM {
    p1 ← llvm_extract_ptr p1;
    p2 ← llvm_extract_ptr p2;
    r ← llvmt_ptr_neq p1 p2;
    return (LL_INT (bool_to_lint r))
  }"  
  
end
end