読者です 読者をやめる 読者になる 読者になる

レガシーコード生産ガイド

私に教えられることなら

Scheme(Gauche)で型推論を書いてみる

Hindley-Milnerの型推論アルゴリズムと、型クラスの実装に興味があるのですが、どこから手をつけていいのかさっぱりわかりません。とりあえず簡単に読めるコード無いかな、と探したところJavaScriptでのHM型推論実装を見つけました。

Type Inference in JavaScript

コードにはしっかりドキュメントとしてのコメントがついていて非常にわかりやすいです。ブラウザ上で見るとさらにわかりやすい。(この表示たまに見かけるけどめちゃくちゃ見やすくていいですね)

This is based on Robert Smallshire's Python code. Which is based on Andrew's Scala code. Which is based on Nikita Borisov's Perl code. Which is based on Luca Cardelli's Modula-2 code. Wow.

ということで更にそれを元にしてSchemeで書いてみました。コードは一番最後にあります。非破壊的コードに直しながらだと更に難しそうなので、簡単なメッセージパッシングスタイルでJSのオブジェクトを模倣しました。

とりあえず((lambda [f x] (f x)) (lambda [x] x) 123)の型をNumberと推論できたのですが、ここからカリー化、代数的データ型、パターンマッチング(コンストラクタ/デストラクタ)、型クラスまで持っていく道筋が全然わかりません。ウーム……

(追記) カリー化はできました(多分)

;; ref: https://brianmckenna.org/blog/js_type_inference

(use util.match)

;; test : datum boolean -> unspecified
;; Reports the outcome of a single load-time check on the current output
;; port. EXPR is the quoted source of the check (used only for printing);
;; RESULT is its value.
;; - RESULT true:  prints "ok <expr>" and returns.
;; - RESULT false: prints " error <expr>" and terminates the program with
;;   exit status 1, so a failing check is visible to shell scripts / CI
;;   (a bare (exit) would report success).
(define (test expr result)
  (if result
      (format #t "ok ~S~%" expr)
      (begin
        (format #t " error ~S~%" expr)
        (exit 1))))

;; (assert EXPR) => (test 'EXPR EXPR)
;; EXPR is evaluated exactly once; its unevaluated source form is handed
;; to `test` so the report line shows what was checked.
(define-macro (assert expr)
  (list 'test (list 'quote expr) expr))


;; Persistent association-list maps. Nothing is mutated: map-set returns a
;; new alist whose front entry shadows any older entry for the same key.

;; map-get : alist key -> value, or #f when KEY is absent (equal? lookup)
(define (map-get m k)
  (assoc-ref m k))

;; map-set : alist key value -> alist
(define (map-set m k v)
  (cons (cons k v) m))

;; map-zip : alist keys values -> alist
;; Adds (key . value) pairs from KS/VS left to right, stopping at the end
;; of the shorter list; later pairs end up in front.
(define (map-zip m ks vs)
  (let loop ([acc m] [keys ks] [vals vs])
    (if (or (null? keys) (null? vals))
        acc
        (loop (map-set acc (car keys) (car vals))
              (cdr keys)
              (cdr vals)))))


;; make-counter : -> (next . reset)
;; Returns a pair of closures sharing one private counter that starts at 0.
;;   (car pair): increments the counter and returns the new value (1, 2, ...)
;;   (cdr pair): resets the counter to 0 and returns 0
(define (make-counter)
  (let ([id 0])
    (cons
     ;; Return id explicitly: the value of a set! expression is unspecified
     ;; in standard Scheme and must not be used as the result.
     (lambda [] (set! id (+ id 1)) id)
     (lambda [] (set! id 0) id))))

;; Global counter used by Variable to number type variables;
;; clear-variable-id restarts the numbering.
(define vcounter (make-counter))
(define make-variable-id (car vcounter))
(define clear-variable-id (cdr vcounter))

;; Variable : -> type
;; Creates a type variable: a message-passing object carrying a numeric id
;; from make-variable-id and an optional binding ('() while unbound).
;; Messages:
;;   'id, 'variable? (#t), 'basetype? (#f), 'types (always '()),
;;   'instance, ('set-instance type), 'to-string
;; While unbound it prints as 'aN; once bound it prints as its binding.
(define (Variable)
  (define id (make-variable-id))
  (define bound '())
  (match-lambda*
   [('variable?) #t]
   [('basetype?) #f]
   [('id) id]
   [('types) '()]
   [('instance) bound]
   [('set-instance new-instance) (set! bound new-instance)]
   [('to-string)
    (if (not (null? bound))
        (bound 'to-string)
        (format #f "'a~A" id))]))

;; BaseType : string (list-of type) -> type
;; Creates a type constructor application such as Number or List a.
;; Messages: 'variable? (#f), 'basetype? (#t), 'name, 'types, 'to-string,
;; and ('set-to-string f) to replace the printer; f is called as
;; (f name types) and must return a string.
;; The default printer yields NAME alone when there are no parameters,
;; otherwise NAME followed by the displayed list of parameter strings.
(define (BaseType name types)
  (define (render-default type-name params)
    (cond
     [(null? params) type-name]
     [else
      (let ([param-strings (map (lambda [t] (t 'to-string)) params)])
        (format #f "~A ~A" type-name param-strings))]))
  (define printer render-default)
  (match-lambda*
   [('name) name]
   [('types) types]
   [('variable?) #f]
   [('basetype?) #t]
   [('set-to-string f) (set! printer f)]
   [('to-string) (printer name types)]))

;; FunctionType : type ... -> type
;; Builds the "Function" base type whose component types are the argument
;; types followed by the result type, e.g. (FunctionType a b c) is
;; a -> b -> c. It prints as "(a -> b -> c)".
(define (FunctionType . types)
  (define (arrow-printer _name params)
    (string-append
     "("
     (string-join (map (lambda [t] (t 'to-string)) params) " -> ")
     ")"))
  (let ([fn (BaseType "Function" types)])
    (fn 'set-to-string arrow-printer)
    fn))

;; Constructors for the built-in types: two ground types and List of T.
(define NumberType
  (lambda () (BaseType "Number" '())))
(define BooleanType
  (lambda () (BaseType "Boolean" '())))
(define ListType
  (lambda (type) (BaseType "List" (cons type '()))))


;; prune : type -> type
;; Follows a chain of bound type variables to the type at its end (an
;; unbound Variable or a non-variable type) and returns it. Every variable
;; on the way is re-pointed directly at that result (path compression),
;; so later lookups are shorter.
(define (prune t1)
  (cond
   [(not (t1 'variable?)) t1]
   [(null? (t1 'instance)) t1]
   [else
    (let ([representative (prune (t1 'instance))])
      (t1 'set-instance representative)
      representative)]))


;; occurs-in-type : variable type -> boolean-ish
;; True when the type variable T1 appears anywhere inside T2 (after
;; following bindings). Used as the occurs check and to detect
;; non-generic variables.
(define (occurs-in-type t1 t2)
  (let ([target (prune t2)])
    (or (eq? t1 target)
        (and (target 'basetype?)
             (occurs-in-types t1 (target 'types))))))

;; occurs-in-types : variable (list-of type) -> type or #f
;; Returns the first element of TYPES that contains T1, or #f if none does.
(define (occurs-in-types t1 types)
  (let scan ([rest types])
    (cond
     [(null? rest) #f]
     [(occurs-in-type t1 (car rest)) (car rest)]
     [else (scan (cdr rest))])))


;; unify : type type -> unspecified
;; Destructively makes T1 and T2 the same type by binding type variables
;; (via 'set-instance). Both sides are pruned first so bound variables are
;; seen through.
;; Errors:
;;   "Recursive Unification" - a variable would be bound to a type that
;;                             contains it (occurs check)
;;   "<t1> is not <t2>"      - constructor names or argument counts differ
(define (unify t1 t2)
  (let* ([t1 (prune t1)]
         [t2 (prune t2)])
    (cond
     [(t1 'variable?)
      ;; Unifying a variable with itself is a no-op. Otherwise run the
      ;; occurs check BEFORE binding, so no infinite type such as
      ;; a = (a -> b) is ever created.
      (unless (eq? t1 t2)
        (when (occurs-in-type t1 t2) (error "Recursive Unification"))
        (t1 'set-instance t2))]
     [(and (t1 'basetype?) (t2 'variable?)) (unify t2 t1)]
     [(and (t1 'basetype?) (t2 'basetype?))
      (if (not (and (equal? (t1 'name) (t2 'name))
                    (= (length (t1 'types)) (length (t2 'types)))))
          (error (format #f "~A is not ~A" (t1 'to-string) (t2 'to-string)))
          ;; Unification is side-effecting: for-each guarantees
          ;; left-to-right order, map does not.
          (for-each unify (t1 'types) (t2 'types)))]
     [else (error "Not unified")])))


;; fresh : type (list-of type) -> type
;; Returns a copy of TYPE in which every generic type variable is replaced
;; by a new Variable. A variable is generic unless it occurs in some type in
;; NON-GENERICS; non-generic variables are kept as is. All occurrences of
;; the same original variable map to the same new variable, including
;; occurrences in different nested positions. Function types are rebuilt
;; with FunctionType so the copy keeps the "(a -> b)" printer.
(define (fresh type non-generics)
  ;; A walk result is a pair (new-type . mapping). The mapping is an alist
  ;; keyed by the original Variable objects themselves, not by their ids:
  ;; ids are reset by clear-variable-id and so are not unique over time.
  (define (ret type mapping) (cons type mapping))
  (define (get-type ret) (car ret))
  (define (get-map  ret) (cdr ret))

  ;; Copies each element of TYPES left to right, threading MAPPING through
  ;; every element. Returns (new-types . final-mapping).
  (define (fold-map types mapping)
    (let loop ([types types] [mapping mapping] [acc '()])
      (if (null? types)
          (ret (reverse acc) mapping)
          (let ([walked (walk (car types) mapping)])
            (loop (cdr types)
                  (get-map walked)
                  (cons (get-type walked) acc))))))

  ;; Copies a single type, returning (new-type . updated-mapping).
  (define (walk type mapping)
    (let ([type (prune type)])
      (if (type 'variable?)
          (cond
           [(occurs-in-types type non-generics) (ret type mapping)]
           [(map-get mapping type) => (lambda (copy) (ret copy mapping))]
           [else
            (let ([new-var (Variable)])
              (ret new-var (map-set mapping type new-var)))])
          (let* ([walked (fold-map (type 'types) mapping)]
                 [new-types (get-type walked)])
            (ret (if (equal? (type 'name) "Function")
                     (apply FunctionType new-types)
                     (BaseType (type 'name) new-types))
                 (get-map walked))))))

  (get-type (walk type '())))


;; type-analyse : expr env -> type
;; Infers the type of EXPR in the type environment ENV, an alist from
;; symbols to types (see GLOBAL-ENV). Supported forms:
;;   (lambda [arg ...] body)  single-body lambda
;;   (f arg ...)              application; supplying fewer arguments than
;;                            the callee's function type has parameters
;;                            yields a function over the remaining ones
;;   symbol, boolean, number
;; Resets the global variable-id counter first, so variables created by this
;; analysis are numbered from the start again. Signals an error for
;; undefined symbols, unsupported literals, and types that fail to unify.
(define (type-analyse expr env)
  (define (analyse expr env non-generics)
    ;; Each parameter gets a new variable that is non-generic while the
    ;; body is analysed (fresh leaves non-generic variables untouched).
    ;; Result: (arg1 -> ... -> argN -> body-type).
    (define (an-lambda args body)
      (let* ([types (map (lambda [_] (Variable)) args)]
             [env (map-zip env args types)]
             [non-generics (append types non-generics)]
             [result-type (analyse body env non-generics)]
             [ftype (append types (list result-type))])
        (apply FunctionType ftype)))
    ;; Application. The callee's type is unified with
    ;; (arg-types ... -> result ...); when the callee is a base type with
    ;; more parameters than supplied arguments, extra result variables are
    ;; added so the lengths match, and the leftover positions form the
    ;; returned function type (currying). Over-application is not padded
    ;; and fails in unify.
    (define (an-call f args)
      (let* ([types (map (lambda [arg] (analyse arg env non-generics)) args)]
             [result (list (Variable))]
             [len (length types)]
             ;; Prune so a variable already bound to a function type (e.g.
             ;; the result of another call) is seen as that function type;
             ;; otherwise the currying adjustment below would be skipped.
             [ftype (prune (analyse f env non-generics))]
             [ftypes (ftype 'types)]
             [flen (length ftypes)])
        ;; currying
        (when (and (ftype 'basetype?) (not (= (+ len 1) flen)))
          (let loop ([i (- flen len 1)])
            (when (> i 0)
              (set! result (cons (Variable) result))
              (loop (- i 1)))))
        (unify ftype (apply FunctionType (append types result)))
        (if (> (length result) 1)
            (apply FunctionType result)
            (car result))))
    ;; Identifiers get a fresh copy of their environment type, so generic
    ;; variables can be instantiated differently at each use.
    (define (an-identifier sym)
      (let* ([type (map-get env sym)])
        (if type
            (fresh type non-generics)
            (error (format #f "Undefined: ~S" sym)))))
    (match expr
      [('lambda args body) (an-lambda args body)]
      [(f . args) (an-call f args)]
      [expr
       (cond
        [(symbol? expr) (an-identifier expr)]
        [(boolean? expr) (BooleanType)]
        [(number? expr) (NumberType)]
        [else (error (format #f "Unknown type of: ~S~%" expr))])]))
  (clear-variable-id)
  (analyse expr env '()))

;; Type environment available to type-str: an alist from symbols to types.
;;   + : Number -> Number -> Number
;;   = : Number -> Number -> Boolean
(define GLOBAL-ENV
  (list (cons '+ (FunctionType (NumberType) (NumberType) (NumberType)))
        (cons '= (FunctionType (NumberType) (NumberType) (BooleanType)))))

;; type-str : expr -> string
;; Infers EXPR's type under GLOBAL-ENV and returns its printed form.
(define (type-str expr)
  (let ([inferred (type-analyse expr GLOBAL-ENV)])
    (inferred 'to-string)))


;; tests
;; -----
;; Load-time self checks: each assert prints "ok ..." or " error ..." and
;; stops the program at the first failure.
;; Expected names like 'a1 depend on the global variable-id counter: in
;; this file the first check below creates the first Variable, and every
;; type-analyse call (through type-str) resets the counter.
(assert (equal? ((Variable) 'to-string) "'a1"))
(assert (equal? ((FunctionType (BooleanType) (NumberType)) 'to-string)
                "(Boolean -> Number)"))

;; unify
;; A free variable unified with a concrete type then prints as that type.
(let ([a (Variable)])
  (unify a
         (FunctionType (NumberType) (BooleanType)))
  (assert (equal? (a 'to-string) "(Number -> Boolean)")))

;; Unifying structurally matching types binds the variable nested in List.
(let ([list-type (Variable)])
  (unify (FunctionType (ListType list-type) (NumberType))
         (FunctionType (ListType (NumberType)) (NumberType)))
  (assert (equal? (list-type 'to-string) "Number")))

;; fresh replaces the generic variable with a new one, so the printed
;; form (the variable's name) differs from the original's.
(let* ([t (FunctionType (ListType (NumberType)) (Variable))]
       [old (t 'to-string)]
       [new ((fresh t '()) 'to-string)])
  (assert (not (equal? old new))))


;; End-to-end inference of literals, lambdas and applications with
;; GLOBAL-ENV; each type-str call numbers variables from 'a1 again.
(assert (equal? (type-str 1) "Number"))
(assert (equal? (type-str #f) "Boolean"))
(assert (equal? (type-str '(lambda [x] 1)) "('a1 -> Number)"))
(assert (equal? (type-str '(lambda [x] x)) "('a1 -> 'a1)"))
(assert (equal? (type-str '((lambda [x] x) 1)) "Number"))
(assert
 (equal?
  (type-str '((lambda [f x] (f x)) (lambda [x] x) 123))
  "Number"))

;; currying
;; fargc : expr -> integer
;; Number of component types in the type inferred for EXPR under an empty
;; environment; for a function type that is its parameter count plus one
;; (the result type).
(define (fargc expr)
  (let ([inferred (type-analyse expr '())])
    (length (inferred 'types))))

;; A three-parameter lambda has four component types. Applying it to k
;; arguments leaves a function over the remaining 3 - k parameters, which
;; has 4 - k component types; applying it to no arguments leaves the type
;; unchanged.
(assert (= (fargc '(lambda [x y z] x))         4)) ;; (a -> b -> c -> a)
(assert (= (fargc '((lambda [x y z] x)))       4)) ;; (a -> b -> c -> a)
(assert (= (fargc '((lambda [x y z] x) 1))     3)) ;; (b -> c -> Number)
(assert (= (fargc '((lambda [x y z] x) 1 2))   2)) ;; (c -> Number)

;; global env and currying
;; The body uses + and = from GLOBAL-ENV, which forces x, y and z to
;; Number. Each supplied argument removes one leading parameter from the
;; printed function type, and supplying all three yields the result type.
(assert
 (equal?
  (type-str '(lambda [x y z] (= (+ x y) z)))
  "(Number -> Number -> Number -> Boolean)"))

(assert
 (equal?
  (type-str '((lambda [x y z] (= (+ x y) z)) 1))
  "(Number -> Number -> Boolean)"))

(assert
 (equal?
  (type-str '((lambda [x y z] (= (+ x y) z)) 1 2))
  "(Number -> Boolean)"))

(assert
 (equal?
  (type-str '((lambda [x y z] (= (+ x y) z)) 1 2 3))
  "Boolean"))
広告を非表示にする