Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,7 @@ endif()
#-------------------------------------------------------------------------------

add_subdirectory(cmake)
add_subdirectory(runtime)
add_subdirectory(frontend)
add_subdirectory(midend)
add_subdirectory(backend)
Expand Down
327 changes: 228 additions & 99 deletions frontend/Python/frontend.py

Large diffs are not rendered by default.

21 changes: 17 additions & 4 deletions frontend/Python/graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,15 +242,14 @@ def displace_node(self, node: Op, newnode: Op):
self.node_table.pop(node.name)
self.node_table[newnode.name] = newnode


def displace_node_with_chain(self, node: Op, chain: list[Op]):
"""
Replaces an existing node with a chain of new nodes.
- The first node is taken to be the "head" of the chain, and all parents of the
current node will have this node as their child instead of `node`
- The last node is taken to be the "tail" of the chain, and all children of `node`
will have this node as their parent instead.

Args:
node (Op): The operation to be replaced.
chain (list[Op]): The list of nodes to be inserted in place of the original Op
Expand Down Expand Up @@ -281,7 +280,7 @@ def displace_node_with_chain(self, node: Op, chain: list[Op]):
node._children.clear()

node_idx = self._body.index(node)
self._body = self.body[:node_idx] + chain + self.body[node_idx+1:]
self._body = self.body[:node_idx] + chain + self.body[node_idx + 1 :]

def init_op_group(self):
"""
Expand Down Expand Up @@ -385,6 +384,12 @@ def lower_to_top_level_ir(self):
np_type = np.dtype(np.uint16)
case "f32":
np_type = np.dtype(np.float32)
case "f64":
np_type = np.dtype(np.float64)
case "complex<f32>":
np_type = np.dtype(np.complex64)
case "complex<f64>":
np_type = np.dtype(np.complex128)
case _:
raise NotImplementedError(f"Unsupported dtype {dtype}")
self._output_memref.append(
Expand Down Expand Up @@ -421,7 +426,8 @@ def lower_to_llvm_ir(self):
pm.add("empty-tensor-to-alloc-tensor")
pm.add("convert-elementwise-to-linalg")
pm.add("one-shot-bufferize{bufferize-function-boundaries}")
pm.add("func.func(convert-linalg-to-affine-loops)")
pm.add("func.func(linalg-generalize-named-ops)")
pm.add("func.func(convert-linalg-to-loops)")
pm.add("affine-loop-fusion")
pm.add("func.func(affine-parallelize)")
pm.add("convert-scf-to-openmp")
Expand All @@ -430,6 +436,7 @@ def lower_to_llvm_ir(self):
pm.add("convert-vector-to-llvm")
pm.add("memref-expand")
pm.add("arith-expand")
pm.add("convert-complex-to-llvm")
pm.add("convert-arith-to-llvm")
pm.add("finalize-memref-to-llvm")
pm.add("convert-scf-to-cf")
Expand Down Expand Up @@ -529,8 +536,14 @@ def _str_to_mlir_dtype(self, dtype: str) -> ir.Type:
return ir.BF16Type.get()
case TensorDType.Float32:
return ir.F32Type.get()
case TensorDType.Float64:
return ir.F64Type.get()
case TensorDType.Bool:
return ir.IntegerType.get_signless(1)
case TensorDType.Complex64:
return ir.ComplexType.get(ir.F32Type.get())
case TensorDType.Complex128:
return ir.ComplexType.get(ir.F64Type.get())
case _:
raise NotImplementedError(f"Unsupported dtype {dtype}")

Expand Down
47 changes: 44 additions & 3 deletions frontend/Python/graph/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,12 @@ def __init__(self) -> None:
self._op_type = OpType.ReshapeType


class ViewDtypeOp(Op):
    """Dtype-reinterpreting view operation.

    NOTE(review): purpose inferred from the class name (aten.view.dtype-style
    reinterpretation) — confirm against the frontend op mapping table.
    """

    def __init__(self) -> None:
        super().__init__()
        # Grouped with reshape-type ops for fusion purposes; presumably this
        # op only changes tensor metadata, not element values — TODO confirm.
        self._op_type = OpType.ReshapeType


class EmbeddingOp(Op):
def __init__(self) -> None:
super().__init__()
Expand Down Expand Up @@ -584,10 +590,12 @@ def __init__(
tensor_meta: dict,
name: str = None,
) -> None:
super().__init__(name)
super().__init__()
if name is not None:
self._name = name
self.call_func_name = call_func_name
self.args = args
self._args_index = args_index
self._arguments = list(args)
self._args_index = list(args_index)
self.tensor_meta = tensor_meta
self._op_type = OpType.Unfusable

Expand Down Expand Up @@ -2057,6 +2065,28 @@ def __init__(self) -> None:
self._op_type = OpType.ReshapeType


class GeometricOp(Op):
    """
    Geometric distribution sampling.
    Implements aten.geometric.default / aten.geometric.out / aten.geometric_.default.
    """

    def __init__(self) -> None:
        super().__init__()
        # NOTE(review): tagged ReshapeType like the other sampling ops
        # (UniformOp, ExponentialOp) even though no reshape occurs — confirm
        # this is the intended fusion category for RNG ops.
        self._op_type = OpType.ReshapeType


class ExponentialOp(Op):
    """
    Exponential distribution sampling.
    Implements aten.exponential.default / aten.exponential.out / aten.exponential_.default.
    """

    def __init__(self) -> None:
        super().__init__()
        # NOTE(review): tagged ReshapeType like the other sampling ops
        # (UniformOp, GeometricOp) even though no reshape occurs — confirm
        # this is the intended fusion category for RNG ops.
        self._op_type = OpType.ReshapeType


class SelectScatterOp(Op):
"""
Select scatter operation.
Expand Down Expand Up @@ -2813,3 +2843,14 @@ class UniformOp(Op):
def __init__(self) -> None:
super().__init__()
self._op_type = OpType.ReshapeType


class BernoulliOp(Op):
    """
    Bernoulli sampling operation.
    Implements aten.bernoulli.* variants (functional forms).
    """

    def __init__(self) -> None:
        super().__init__()
        # NOTE(review): tagged ElementwiseType, unlike the sibling sampling
        # ops (UniformOp/GeometricOp/ExponentialOp, which use ReshapeType) —
        # confirm the divergence is intentional and not a copy-paste slip.
        self._op_type = OpType.ElementwiseType
6 changes: 6 additions & 0 deletions frontend/Python/graph/transform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,9 @@
replace_matmul_with_onednn,
replace_matmul_with_onednn_selective,
)
from .rand_replace import (
replace_bernoulli_with_runtime_rng,
replace_exponential_with_runtime_rng,
replace_geometric_with_runtime_rng,
replace_rand_with_runtime_rng,
)
Loading